# Tutorials
> This bundle contains all pages in the Tutorials section.
> Source: https://www.union.ai/docs/v2/union/tutorials/

=== PAGE: https://www.union.ai/docs/v2/union/tutorials ===

# Tutorials

> **📝 Note**
>
> An LLM-optimized bundle of this entire section is available at [`section.md`](section.md).
> This single file contains all pages in this section, optimized for AI coding agent context.

This section contains tutorials that walk through real-world use cases step by step, showing how to implement them with Flyte and Union. Tutorials are organized by **industry vertical** and by **technical topic**.

## Industry verticals

### **Biotech & healthcare**

Bioinformatics, medical imaging, and other life-sciences workloads.

### **Geospatial**

Satellite imagery, remote sensing, and earth and atmospheric modeling workloads.

### **Financial services & fintech**

Financial research, trading, and other fintech workloads.

### **Frontier AI**

Frontier-model pretraining, automated experimentation, and large-scale AI workloads.

## Technical topics

### **Computer vision**

Image and vision-language model workloads.

### **Agents**

Agentic workflows and autonomous LLM-powered systems.

### **Context engineering**

Prompt engineering, prompt optimization, and context construction.

### **Model training**

Training, fine-tuning, and hyperparameter optimization of models at scale.

### **Data processing**

Large-scale data processing and batching strategies.

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/biotech-healthcare ===

# Biotech & healthcare

Tutorials for bioinformatics, medical imaging, and other life-sciences workloads.

### **Genomic alignment**

Align sequencing reads to a reference genome with a cached, parallel Bowtie 2 pipeline.

### **Cross-species gene comparison**

Compare homologous genes across species with Carbon scoring, sequence alignment, and ESMFold 3D structures.

### **Genomic variant effect prediction**

Zero-shot pathogenicity scoring with HuggingFace Carbon and interactive VEP reports.

### **Brain tumor MRI classification**

Classify brain MRI scans with a two-phase EfficientNet-B4 pipeline featuring resumable GPU checkpointing and in-UI reports.

### **Drug molecule screening agent**

Agentic virtual screening with RDKit stage tools, Lipinski filters, and ranked drug-likeness reports.

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/biotech-healthcare/genomic-alignment ===

# Genomic alignment

> [!NOTE]
> Code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/genomic_alignment).

This tutorial builds a bioinformatics pipeline that aligns raw sequencing reads to a reference genome. The workflow downloads a reference genome and paired-end sequencing data, performs quality filtering, builds a reference index, and aligns the filtered reads with the [Bowtie 2](https://bowtie-bio.sourceforge.net/bowtie2/index.shtml) aligner — running each sample in parallel.

It's a good showcase of how Flyte handles real bioinformatics workloads:

- **Per-task resources**, so quality filtering (2Gi memory), indexing (10Gi memory), and alignment (2 CPUs, 10Gi memory) each request only what they need.
- **`cache="auto"`** on the download and indexing steps, so re-runs skip work that hasn't changed.
- **Fan-out parallelism** across samples with `asyncio.gather`.
- **System dependencies** (`fastp`, `bowtie2`) installed into the container image with `apt`.
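You can try the fan-out pattern above in plain Python, without a cluster. This is a minimal sketch rather than the tutorial's code. `align_sample` is a hypothetical stand-in coroutine that sleeps instead of running Bowtie 2. In the Flyte pipeline, each awaited call roughly corresponds to a task invocation that runs in its own container.

```python
import asyncio
import time


async def align_sample(sample: str) -> str:
    """Hypothetical stand-in for an alignment task: simulate 0.1 s of work."""
    await asyncio.sleep(0.1)
    return f"{sample}_bowtie2_aligned.sam"


async def align_all(samples: list[str]) -> list[str]:
    # Create one coroutine per sample and run them concurrently.
    # asyncio.gather returns results in argument order, regardless of
    # which coroutine finishes first.
    return list(await asyncio.gather(*(align_sample(s) for s in samples)))


if __name__ == "__main__":
    start = time.perf_counter()
    results = asyncio.run(align_all(["sampleA", "sampleB", "sampleC"]))
    elapsed = time.perf_counter() - start
    print(results)
    # Three 0.1 s jobs run concurrently, so the total is close to 0.1 s, not 0.3 s.
    print(f"elapsed: {elapsed:.2f}s")
```

The tutorial uses the same shape: build a list of calls, then `await asyncio.gather(*calls)`.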

## Define the container image

Because the pipeline shells out to bioinformatics tools, we build a custom image with `flyte.Image.from_uv_script` and install `fastp` (quality filtering) and `bowtie2` (alignment) via `apt`. The complete source file is shown below. The image itself is the `main_img` definition.

```python
# # Genomic Alignment
#
# This tutorial demonstrates how to use Flyte to build a workflow that
# performs genomic alignment on sequencing data. The workflow takes as input
# a reference genome and raw sequencing data, performs quality filtering and
# preprocessing on the raw data, generates an index for the reference genome,
# and aligns the filtered data to the reference genome using the Bowtie 2 aligner.

# {{run-on-union}}

# The tutorial is divided into the following sections:
# 1. Define the container image
# 2. Define the data classes
# 3. Define the tasks
# 4. Define the workflow

# /// script
# requires-python = "3.12"
# dependencies = [
#    "flyte",
#    "requests",
# ]
# main = "alignment_wf"
# params = ""
# ///

import asyncio
import subprocess
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import List

import requests
import flyte
from flyte.io import Dir, File

# ## Defining a Container Image
#
# We define a custom container image using `flyte.Image`. Since we need bioinformatics
# tools — `fastp` for quality filtering and `bowtie2` for alignment — we install them
# via apt. This approach replaces the v1 `ImageSpec` with conda channels.

# {{docs-fragment image}}
main_img = (
    flyte.Image.from_uv_script(
        __file__,
        name="alignment-tutorial",
    )
    .with_apt_packages("fastp", "bowtie2")
)
# {{/docs-fragment image}}

# We define per-task environments with different resource requirements, then a
# top-level `base_env` that declares all of them as dependencies (required because
# `alignment_wf` and `bowtie2_align_samples` call tasks that live in those environments).

# {{docs-fragment envs}}
fetch_env = flyte.TaskEnvironment(
    name="alignment-tutorial-fetch",
    image=main_img,
    cache="auto",
)

fastp_env = flyte.TaskEnvironment(
    name="alignment-tutorial-fastp",
    image=main_img,
    resources=flyte.Resources(memory="2Gi"),
)

index_env = flyte.TaskEnvironment(
    name="alignment-tutorial-index",
    image=main_img,
    resources=flyte.Resources(memory="10Gi"),
    cache="auto",
)

align_env = flyte.TaskEnvironment(
    name="alignment-tutorial-align",
    image=main_img,
    resources=flyte.Resources(cpu=2, memory="10Gi"),
)

base_env = flyte.TaskEnvironment(
    name="alignment-tutorial",
    image=main_img,
    depends_on=[fetch_env, fastp_env, index_env, align_env],
)
# {{/docs-fragment envs}}

# ## Defining Data Classes
#
# We define three data classes to represent the reference genome, sequencing reads,
# and alignment results. We'll first define a convenience function to download files,
# which we'll use within the fetch task to materialize assets from their remote locations.

def fetch_file(url: str, local_dir: str) -> Path:
    """
    Downloads a file from the specified URL.

    Args:
        url (str): The URL of the file to download.
        local_dir (str): The directory where you would like this file saved.

    Returns:
        Path: The local path to the file.

    Raises:
        requests.HTTPError: If an HTTP error occurs while downloading the file.
    """
    url_parts = url.split("/")
    fname = url_parts[-1]
    local_path = Path(local_dir) / fname

    response = requests.get(url)
    response.raise_for_status()  # raise requests.HTTPError on 4xx/5xx, as the docstring promises
    with open(local_path, "wb") as file:
        file.write(response.content)

    return local_path

# Reference genomes are used extensively throughout bioinformatics workflows. We define a
# `Reference` data class to represent a reference genome and its associated index files.

# {{docs-fragment dataclasses}}
@dataclass
class Reference:
    """
    Represents a reference FASTA and associated index files.

    Attributes:
        ref_name (str): Name or identifier of the reference file.
        ref_dir (Dir): Directory containing the reference and any index files.
        index_name (str): Index string to pass to tools requiring it.
        indexed_with (str): Name of tool used to create the index.
    """

    ref_name: str
    ref_dir: Dir
    index_name: str | None = None
    indexed_with: str | None = None

# Sequencing reads are the raw data generated from a sequencing experiment.

@dataclass
class Reads:
    """
    Represents a sequencing reads sample via its associated FastQ files.

    Attributes:
        sample (str): The name or identifier of the raw sequencing sample.
        read1 (File): A File object representing the path to the raw R1 read file.
        read2 (File): A File object representing the path to the raw R2 read file.
    """

    sample: str
    read1: File | None = None
    read2: File | None = None

    def get_read_fnames(self):
        return (
            f"{self.sample}_1.fastq.gz",
            f"{self.sample}_2.fastq.gz",
        )

# Finally, we define an `Alignment` data class to represent an alignment file.

@dataclass
class Alignment:
    """
    Represents an alignment file and its associated sample.

    Attributes:
        sample (str): The name or identifier of the sample.
        aligner (str): The name of the aligner used to generate the alignment file.
        format (str): The format of the alignment file (e.g., SAM, BAM).
        alignment (File): A File object representing the path to the alignment file.
    """

    sample: str
    aligner: str
    format: str | None = None
    alignment: File | None = None

    def get_alignment_fname(self):
        return f"{self.sample}_{self.aligner}_aligned.{self.format}"
# {{/docs-fragment dataclasses}}

# ## Tasks
#
# We define a series of tasks to perform the following operations:
# 1. Fetch assets from remote URLs
# 2. Perform quality filtering and preprocessing using FastP
# 3. Generate Bowtie2 index files from a reference genome
# 4. Perform alignment using Bowtie2 on a filtered sample
#
# The first task fetches the reference genome and sequencing reads. It is cached
# so that re-runs skip the download step.

# {{docs-fragment fetch_assets}}
@fetch_env.task
async def fetch_assets(
    ref_url: str, read_urls: List[str]
) -> tuple[Reference, List[Reads]]:
    """
    Fetch assets from remote URLs.
    """
    # Download reference genome
    ref_dir = Path("/tmp/reference_genome")
    ref_dir.mkdir(exist_ok=True, parents=True)
    ref = fetch_file(ref_url, str(ref_dir))
    ref_obj = Reference(
        ref_name=ref.name,
        ref_dir=await Dir.from_local(str(ref_dir)),
    )

    # Download sequencing reads
    dl_loc = Path("/tmp/reads")
    dl_loc.mkdir(exist_ok=True, parents=True)

    samples: dict[str, Reads] = {}
    for url in read_urls:
        fp = fetch_file(url, str(dl_loc))
        sample = fp.stem.split("_")[0]

        if sample not in samples:
            samples[sample] = Reads(sample=sample)

        if ".fastq.gz" in fp.name or "fasta" in fp.name:
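            # str.strip() removes a set of characters, not a suffix; removesuffix() is what is meant
            mate = fp.name.removesuffix(".fastq.gz").removesuffix(".filt").split("_")[-1]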
            if "1" in mate:
                samples[sample].read1 = await File.from_local(str(fp))
            elif "2" in mate:
                samples[sample].read2 = await File.from_local(str(fp))

    return ref_obj, list(samples.values())
# {{/docs-fragment fetch_assets}}

# The second task performs quality filtering and preprocessing using FastP on a Reads object.
# FastP is a performant tool for removing duplicate or low-quality reads. We increase
# the memory request for this task so FastP can efficiently process reads from larger files.

# {{docs-fragment pyfastp}}
@fastp_env.task
async def pyfastp(rs: Reads) -> Reads:
    """
    Perform quality filtering and preprocessing using Fastp on a Reads object.

    Args:
        rs (Reads): A Reads object containing raw sequencing data to be processed.

    Returns:
        Reads: A Reads object representing the filtered and preprocessed data.
    """
    ldir = Path(tempfile.mkdtemp())
    samp = Reads(rs.sample)
    o1, o2 = samp.get_read_fnames()
    o1p = ldir / o1
    o2p = ldir / o2

    assert rs.read1 is not None and rs.read2 is not None
    r1 = await rs.read1.download()
    r2 = await rs.read2.download()

    cmd = [
        "fastp",
        "-i", str(r1),
        "-I", str(r2),
        "-o", str(o1p),
        "-O", str(o2p),
    ]
    subprocess.run(cmd, check=True)

    samp.read1 = await File.from_local(str(o1p))
    samp.read2 = await File.from_local(str(o2p))

    return samp
# {{/docs-fragment pyfastp}}

# Next, we define a task to generate Bowtie2 index files from a reference genome. As the index
# for a given tool and reference seldom changes, we cache this task.

# {{docs-fragment bowtie2_index}}
@index_env.task
async def bowtie2_index(ref: Reference) -> Reference:
    """
    Generate Bowtie2 index files from a reference genome.

    Args:
        ref (Reference): A Reference object representing the reference genome.

    Returns:
        Reference: The same reference object with the index_name and indexed_with attributes set.
    """
    ref_dir = await ref.ref_dir.download()
    idx_name = "bt2_idx"
    cmd = [
        "bowtie2-build",
        str(Path(str(ref_dir)) / ref.ref_name),
        str(Path(str(ref_dir)) / idx_name),
    ]
    subprocess.run(cmd, check=True)
    return Reference(
        ref.ref_name,
        await Dir.from_local(str(ref_dir)),
        idx_name,
        "bowtie2",
    )
# {{/docs-fragment bowtie2_index}}

# The next task performs paired-end alignment using Bowtie 2 on a single sample.

# {{docs-fragment bowtie2_align}}
@align_env.task
async def bowtie2_align_paired_reads(idx: Reference, fs: Reads) -> Alignment:
    """
    Perform paired-end alignment using Bowtie 2 on a filtered sample.

    Args:
        idx (Reference): A Reference object containing the Bowtie 2 index.
        fs (Reads): A filtered Reads object containing sample data to be aligned.

    Returns:
        Alignment: An Alignment object representing the alignment result.
    """
    assert idx.indexed_with == "bowtie2", "Reference index must be generated with bowtie2"
    assert idx.index_name is not None
    assert fs.read1 is not None and fs.read2 is not None

    ref_dir = await idx.ref_dir.download()
    r1 = await fs.read1.download()
    r2 = await fs.read2.download()

    ldir = Path(tempfile.mkdtemp())
    alignment = Alignment(fs.sample, "bowtie2", "sam")
    al = ldir / alignment.get_alignment_fname()

    cmd = [
        "bowtie2",
        "-x", str(Path(str(ref_dir)) / idx.index_name),
        "-1", str(r1),
        "-2", str(r2),
        "-S", str(al),
    ]
    subprocess.run(cmd, check=True)

    alignment.alignment = await File.from_local(str(al))
    return alignment
# {{/docs-fragment bowtie2_align}}

# In place of the v1 `@dynamic` workflow, we use a plain async task with `asyncio.gather`
# to run alignments for all samples in parallel.

@base_env.task
async def bowtie2_align_samples(
    idx: Reference, samples: List[Reads]
) -> List[Alignment]:
    """
    Process samples through bowtie2 in parallel.

    Args:
        idx (Reference): A Reference object containing the Bowtie 2 index.
        samples (List[Reads]): A list of Reads objects to be aligned.

    Returns:
        List[Alignment]: A list of Alignment objects representing the alignment results.
    """
    tasks = [bowtie2_align_paired_reads(idx=idx, fs=sample) for sample in samples]
    return list(await asyncio.gather(*tasks))

# ## End-to-End Workflow
#
# We tie everything together in a final task that fetches assets, filters them, generates
# an index, and aligns the samples. In place of the v1 `@workflow`, we use a top-level
# `@base_env.task`. Parallelism across samples is achieved with `asyncio.gather`.

# {{docs-fragment workflow}}
@base_env.task
async def alignment_wf() -> List[Alignment]:
    # Prepare raw samples from remote URLs
    ref, samples = await fetch_assets(
        ref_url="https://github.com/unionai-oss/unionbio/raw/main/tests/assets/references/GRCh38_short.fasta",
        read_urls=[
            "https://github.com/unionai-oss/unionbio/raw/main/tests/assets/sequences/raw/ERR250683-tiny_1.fastq.gz",
            "https://github.com/unionai-oss/unionbio/raw/main/tests/assets/sequences/raw/ERR250683-tiny_2.fastq.gz",
        ],
    )

    # Filter all samples in parallel
    filtered_samples = list(
        await asyncio.gather(*[pyfastp(rs=s) for s in samples])
    )

    # Generate a bowtie2 index or load it from cache
    bowtie2_idx = await bowtie2_index(ref=ref)

    # Generate alignments using bowtie2
    sams = await bowtie2_align_samples(idx=bowtie2_idx, samples=filtered_samples)

    return sams
# {{/docs-fragment workflow}}

# You can now run the workflow using the command in the dropdown at the top of the page!

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(alignment_wf)
    print(run.url)
    run.wait()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/genomic_alignment/genomic_alignment.py*
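The mate-detection step in `fetch_assets` depends on removing a file extension. Python's `str.strip()` removes any of the given *characters* from both ends, not a suffix, so it only handles the tutorial's file names by accident. `str.removesuffix()` (Python 3.9+) expresses the intent directly. The helper below is a hypothetical, standalone sketch of that parsing logic. It assumes names of the form `<sample>_<mate>[.filt].fastq.gz` and is not part of the tutorial.

```python
from pathlib import Path


def parse_read_name(path: str) -> tuple[str, str]:
    """Split a paired-end FASTQ name like 'ERR250683-tiny_1.fastq.gz' into (sample, mate).

    Hypothetical helper; assumes the '<sample>_<mate>[.filt].fastq.gz' naming convention.
    """
    name = Path(path).name
    base = name.removesuffix(".fastq.gz").removesuffix(".filt")
    sample, _, mate = base.rpartition("_")
    if not sample or mate not in {"1", "2"}:
        raise ValueError(f"unrecognized read file name: {name}")
    return sample, mate


if __name__ == "__main__":
    print(parse_read_name("/tmp/reads/ERR250683-tiny_1.fastq.gz"))  # ('ERR250683-tiny', '1')
    # str.strip() treats its argument as a character set, so the leading "sa" is eaten too:
    print("sample_2.fastq.gz".strip(".fastq.gz"))  # prints mple_2
```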

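The Python dependencies are declared at the top of the file as PEP 723 inline script metadata: the `# /// script` comment block that `uv` reads. `flyte.Image.from_uv_script` uses this header to install the dependencies into the image: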

```
# /// script
# requires-python = "3.12"
# dependencies = [
#    "flyte",
#    "requests",
# ]
# main = "alignment_wf"
# ///
```

## Define the task environments

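Each stage runs in its own `TaskEnvironment`, sized to its workload. Filtering requests 2Gi of memory, indexing requests 10Gi, and alignment requests 2 CPUs and 10Gi. The fetch and index environments also enable `cache="auto"`. The top-level `base_env` lists the others in `depends_on` because the tasks that run in it (`alignment_wf` and `bowtie2_align_samples`) call tasks from those environments.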

```python
fetch_env = flyte.TaskEnvironment(
    name="alignment-tutorial-fetch",
    image=main_img,
    cache="auto",
)

fastp_env = flyte.TaskEnvironment(
    name="alignment-tutorial-fastp",
    image=main_img,
    resources=flyte.Resources(memory="2Gi"),
)

index_env = flyte.TaskEnvironment(
    name="alignment-tutorial-index",
    image=main_img,
    resources=flyte.Resources(memory="10Gi"),
    cache="auto",
)

align_env = flyte.TaskEnvironment(
    name="alignment-tutorial-align",
    image=main_img,
    resources=flyte.Resources(cpu=2, memory="10Gi"),
)

base_env = flyte.TaskEnvironment(
    name="alignment-tutorial",
    image=main_img,
    depends_on=[fetch_env, fastp_env, index_env, align_env],
)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/genomic_alignment/genomic_alignment.py*
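One way to picture the `depends_on` relationship is as a graph. The sketch below uses a hypothetical `Env` dataclass, not Flyte's `TaskEnvironment`. Using the parent environment means bringing along every environment it declares, transitively, so that the tasks it calls have somewhere to run. This only illustrates the idea. It does not show how Flyte actually resolves or deploys environments.

```python
from dataclasses import dataclass, field


@dataclass(frozen=True)
class Env:
    """Hypothetical stand-in for a task environment: a name plus declared dependencies."""
    name: str
    depends_on: tuple["Env", ...] = field(default=())


def collect_envs(root: Env) -> list[str]:
    """Depth-first walk returning root and every transitively declared dependency, once each."""
    seen: dict[str, None] = {}  # dict preserves insertion order

    def visit(env: Env) -> None:
        if env.name in seen:
            return
        seen[env.name] = None
        for dep in env.depends_on:
            visit(dep)

    visit(root)
    return list(seen)


fetch = Env("alignment-tutorial-fetch")
fastp = Env("alignment-tutorial-fastp")
index = Env("alignment-tutorial-index")
align = Env("alignment-tutorial-align")
base = Env("alignment-tutorial", depends_on=(fetch, fastp, index, align))

if __name__ == "__main__":
    print(collect_envs(base))
```

Tracking visited names means a shared dependency is collected once, even if several environments declare it.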

## Define the data classes

We model the reference genome, sequencing reads, and alignment results as dataclasses. `flyte.io.File` and `flyte.io.Dir` reference offloaded data in blob storage, so large genomic files are passed between tasks by reference rather than copied through the orchestrator.
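Passing by reference means that tasks exchange a small, serializable pointer instead of the file bytes. As a plain-Python illustration, serializing a dataclass that holds only a reference produces a payload whose size does not depend on the file it points to. The `BlobRef` and `ReadsRef` classes, the `s3://` URIs, and the file sizes are all hypothetical. This is not how `flyte.io.File` is implemented.

```python
import json
from dataclasses import asdict, dataclass


@dataclass
class BlobRef:
    """Hypothetical pointer to an object in blob storage."""
    uri: str
    size_bytes: int


@dataclass
class ReadsRef:
    """Hypothetical analogue of the tutorial's Reads, holding references instead of data."""
    sample: str
    read1: BlobRef
    read2: BlobRef


def payload(reads: ReadsRef) -> str:
    # Only the metadata is serialized; the FASTQ contents stay in storage.
    return json.dumps(asdict(reads))


if __name__ == "__main__":
    big = ReadsRef(
        sample="ERR250683",
        read1=BlobRef("s3://example-bucket/reads/ERR250683_1.fastq.gz", 20 * 1024**3),
        read2=BlobRef("s3://example-bucket/reads/ERR250683_2.fastq.gz", 20 * 1024**3),
    )
    print(len(payload(big)), "bytes of metadata for ~40 GiB of (hypothetical) reads")
```

In the tutorial's dataclasses, the `File` and `Dir` fields play the role of `BlobRef`.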

```
# # Genomic Alignment
#
# This tutorial demonstrates how to use Flyte to build a workflow that
# performs genomic alignment on sequencing data. The workflow takes as input
# a reference genome and raw sequencing data, performs quality filtering and
# preprocessing on the raw data, generates an index for the reference genome,
# and aligns the filtered data to the reference genome using the Bowtie 2 aligner.

# {{run-on-union}}

# The tutorial is divided into the following sections:
# 1. Define the container image
# 2. Define the data classes
# 3. Define the tasks
# 4. Define the workflow

# /// script
# requires-python = "3.12"
# dependencies = [
#    "flyte",
#    "requests",
# ]
# main = "alignment_wf"
# params = ""
# ///

import asyncio
import subprocess
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import List

import requests
import flyte
from flyte.io import Dir, File

# ## Defining a Container Image
#
# We define a custom container image using `flyte.Image`. Since we need bioinformatics
# tools — `fastp` for quality filtering and `bowtie2` for alignment — we install them
# via apt. This approach replaces the v1 `ImageSpec` with conda channels.

# {{docs-fragment image}}
main_img = (
    flyte.Image.from_uv_script(
        __file__,
        name="alignment-tutorial",
    )
    .with_apt_packages("fastp", "bowtie2")
)
# {{/docs-fragment image}}

# We define per-task environments with different resource requirements, then a
# top-level `base_env` that declares all of them as dependencies (required because
# `alignment_wf` and `bowtie2_align_samples` call tasks that live in those environments).

# {{docs-fragment envs}}
fetch_env = flyte.TaskEnvironment(
    name="alignment-tutorial-fetch",
    image=main_img,
    cache="auto",
)

fastp_env = flyte.TaskEnvironment(
    name="alignment-tutorial-fastp",
    image=main_img,
    resources=flyte.Resources(memory="2Gi"),
)

index_env = flyte.TaskEnvironment(
    name="alignment-tutorial-index",
    image=main_img,
    resources=flyte.Resources(memory="10Gi"),
    cache="auto",
)

align_env = flyte.TaskEnvironment(
    name="alignment-tutorial-align",
    image=main_img,
    resources=flyte.Resources(cpu=2, memory="10Gi"),
)

base_env = flyte.TaskEnvironment(
    name="alignment-tutorial",
    image=main_img,
    depends_on=[fetch_env, fastp_env, index_env, align_env],
)
# {{/docs-fragment envs}}

# ## Defining Data Classes
#
# We define three data classes to represent the reference genome, sequencing reads,
# and alignment results. We'll first define a convenience function to download files,
# which we'll use within the fetch task to materialize assets from their remote locations.

def fetch_file(url: str, local_dir: str) -> Path:
    """
    Downloads a file from the specified URL.

    Args:
        url (str): The URL of the file to download.
        local_dir (str): The directory where you would like this file saved.

    Returns:
        Path: The local path to the file.

    Raises:
        requests.HTTPError: If an HTTP error occurs while downloading the file.
    """
    url_parts = url.split("/")
    fname = url_parts[-1]
    local_path = Path(local_dir) / fname

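    response = requests.get(url)
    # Raise requests.HTTPError on 4xx/5xx responses, as the docstring promises,
    # instead of silently saving an error page as the downloaded file
    response.raise_for_status()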
    with open(local_path, "wb") as file:
        file.write(response.content)

    return local_path

# Reference genomes are used extensively throughout bioinformatics workflows. We define a
# `Reference` data class to represent a reference genome and its associated index files.

# {{docs-fragment dataclasses}}
@dataclass
class Reference:
    """
    Represents a reference FASTA and associated index files.

    Attributes:
        ref_name (str): Name or identifier of the reference file.
        ref_dir (Dir): Directory containing the reference and any index files.
        index_name (str): Index string to pass to tools requiring it.
        indexed_with (str): Name of tool used to create the index.
    """

    ref_name: str
    ref_dir: Dir
    index_name: str | None = None
    indexed_with: str | None = None

# Sequencing reads are the raw data generated from a sequencing experiment.

@dataclass
class Reads:
    """
    Represents a sequencing reads sample via its associated FastQ files.

    Attributes:
        sample (str): The name or identifier of the raw sequencing sample.
        read1 (File): A File object representing the path to the raw R1 read file.
        read2 (File): A File object representing the path to the raw R2 read file.
    """

    sample: str
    read1: File | None = None
    read2: File | None = None

    def get_read_fnames(self):
        return (
            f"{self.sample}_1.fastq.gz",
            f"{self.sample}_2.fastq.gz",
        )

# Finally, we define an `Alignment` data class to represent an alignment file.

@dataclass
class Alignment:
    """
    Represents an alignment file and its associated sample.

    Attributes:
        sample (str): The name or identifier of the sample.
        aligner (str): The name of the aligner used to generate the alignment file.
        format (str): The format of the alignment file (e.g., SAM, BAM).
        alignment (File): A File object representing the path to the alignment file.
    """

    sample: str
    aligner: str
    format: str | None = None
    alignment: File | None = None

    def get_alignment_fname(self):
        return f"{self.sample}_{self.aligner}_aligned.{self.format}"
# {{/docs-fragment dataclasses}}

# ## Tasks
#
# We define a series of tasks to perform the following operations:
# 1. Fetch assets from remote URLs
# 2. Perform quality filtering and preprocessing using FastP
# 3. Generate Bowtie2 index files from a reference genome
# 4. Perform alignment using Bowtie2 on a filtered sample
#
# The first task fetches the reference genome and sequencing reads. It is cached
# so that re-runs skip the download step.

# {{docs-fragment fetch_assets}}
@fetch_env.task
async def fetch_assets(
    ref_url: str, read_urls: List[str]
) -> tuple[Reference, List[Reads]]:
    """
    Fetch assets from remote URLs.
    """
    # Download reference genome
    ref_dir = Path("/tmp/reference_genome")
    ref_dir.mkdir(exist_ok=True, parents=True)
    ref = fetch_file(ref_url, str(ref_dir))
    ref_obj = Reference(
        ref_name=ref.name,
        ref_dir=await Dir.from_local(str(ref_dir)),
    )

    # Download sequencing reads
    dl_loc = Path("/tmp/reads")
    dl_loc.mkdir(exist_ok=True, parents=True)

    samples: dict[str, Reads] = {}
    for url in read_urls:
        fp = fetch_file(url, str(dl_loc))
        sample = fp.stem.split("_")[0]

        if sample not in samples:
            samples[sample] = Reads(sample=sample)

        if ".fastq.gz" in fp.name or "fasta" in fp.name:
            # removesuffix drops exact suffixes; str.strip would drop a set of characters
            mate = fp.name.removesuffix(".fastq.gz").removesuffix(".filt").split("_")[-1]
            if "1" in mate:
                samples[sample].read1 = await File.from_local(str(fp))
            elif "2" in mate:
                samples[sample].read2 = await File.from_local(str(fp))

    return ref_obj, list(samples.values())
# {{/docs-fragment fetch_assets}}

# The second task performs quality filtering and preprocessing using FastP on a Reads object.
# FastP is a performant tool for removing duplicate or low-quality reads. We increase
# the memory request for this task so FastP can efficiently process reads from larger files.

# {{docs-fragment pyfastp}}
@fastp_env.task
async def pyfastp(rs: Reads) -> Reads:
    """
    Perform quality filtering and preprocessing using Fastp on a Reads object.

    Args:
        rs (Reads): A Reads object containing raw sequencing data to be processed.

    Returns:
        Reads: A Reads object representing the filtered and preprocessed data.
    """
    ldir = Path(tempfile.mkdtemp())
    samp = Reads(rs.sample)
    o1, o2 = samp.get_read_fnames()
    o1p = ldir / o1
    o2p = ldir / o2

    assert rs.read1 is not None and rs.read2 is not None
    r1 = await rs.read1.download()
    r2 = await rs.read2.download()

    cmd = [
        "fastp",
        "-i", str(r1),
        "-I", str(r2),
        "-o", str(o1p),
        "-O", str(o2p),
    ]
    subprocess.run(cmd, check=True)

    samp.read1 = await File.from_local(str(o1p))
    samp.read2 = await File.from_local(str(o2p))

    return samp
# {{/docs-fragment pyfastp}}

# Next, we define a task to generate Bowtie2 index files from a reference genome. As the index
# for a given tool and reference seldom changes, we cache this task.

# {{docs-fragment bowtie2_index}}
@index_env.task
async def bowtie2_index(ref: Reference) -> Reference:
    """
    Generate Bowtie2 index files from a reference genome.

    Args:
        ref (Reference): A Reference object representing the reference genome.

    Returns:
        Reference: A new Reference pointing at the indexed directory, with index_name and indexed_with set.
    """
    ref_dir = await ref.ref_dir.download()
    idx_name = "bt2_idx"
    cmd = [
        "bowtie2-build",
        str(Path(str(ref_dir)) / ref.ref_name),
        str(Path(str(ref_dir)) / idx_name),
    ]
    subprocess.run(cmd, check=True)
    return Reference(
        ref.ref_name,
        await Dir.from_local(str(ref_dir)),
        idx_name,
        "bowtie2",
    )
# {{/docs-fragment bowtie2_index}}

# The next task performs paired-end alignment using Bowtie 2 on a single sample.

# {{docs-fragment bowtie2_align}}
@align_env.task
async def bowtie2_align_paired_reads(idx: Reference, fs: Reads) -> Alignment:
    """
    Perform paired-end alignment using Bowtie 2 on a filtered sample.

    Args:
        idx (Reference): A Reference object containing the Bowtie 2 index.
        fs (Reads): A filtered Reads object containing sample data to be aligned.

    Returns:
        Alignment: An Alignment object representing the alignment result.
    """
    assert idx.indexed_with == "bowtie2", "Reference index must be generated with bowtie2"
    assert idx.index_name is not None
    assert fs.read1 is not None and fs.read2 is not None

    ref_dir = await idx.ref_dir.download()
    r1 = await fs.read1.download()
    r2 = await fs.read2.download()

    ldir = Path(tempfile.mkdtemp())
    alignment = Alignment(fs.sample, "bowtie2", "sam")
    al = ldir / alignment.get_alignment_fname()

    cmd = [
        "bowtie2",
        "-x", str(Path(str(ref_dir)) / idx.index_name),
        "-1", str(r1),
        "-2", str(r2),
        "-S", str(al),
    ]
    subprocess.run(cmd, check=True)

    alignment.alignment = await File.from_local(str(al))
    return alignment
# {{/docs-fragment bowtie2_align}}

# In place of the v1 `@dynamic` workflow, we use a plain async task with `asyncio.gather`
# to run alignments for all samples in parallel.

@base_env.task
async def bowtie2_align_samples(
    idx: Reference, samples: List[Reads]
) -> List[Alignment]:
    """
    Process samples through bowtie2 in parallel.

    Args:
        idx (Reference): A Reference object containing the Bowtie 2 index.
        samples (List[Reads]): A list of Reads objects to be aligned.

    Returns:
        List[Alignment]: A list of Alignment objects representing the alignment results.
    """
    tasks = [bowtie2_align_paired_reads(idx=idx, fs=sample) for sample in samples]
    return list(await asyncio.gather(*tasks))

# ## End-to-End Workflow
#
# We tie everything together in a final task that fetches assets, filters them, generates
# an index, and aligns the samples. In place of the v1 `@workflow`, we use a top-level
# `@base_env.task`. Parallelism across samples is achieved with `asyncio.gather`.

# {{docs-fragment workflow}}
@base_env.task
async def alignment_wf() -> List[Alignment]:
    # Prepare raw samples from remote URLs
    ref, samples = await fetch_assets(
        ref_url="https://github.com/unionai-oss/unionbio/raw/main/tests/assets/references/GRCh38_short.fasta",
        read_urls=[
            "https://github.com/unionai-oss/unionbio/raw/main/tests/assets/sequences/raw/ERR250683-tiny_1.fastq.gz",
            "https://github.com/unionai-oss/unionbio/raw/main/tests/assets/sequences/raw/ERR250683-tiny_2.fastq.gz",
        ],
    )

    # Filter all samples in parallel
    filtered_samples = list(
        await asyncio.gather(*[pyfastp(rs=s) for s in samples])
    )

    # Generate a bowtie2 index or load it from cache
    bowtie2_idx = await bowtie2_index(ref=ref)

    # Generate alignments using bowtie2
    sams = await bowtie2_align_samples(idx=bowtie2_idx, samples=filtered_samples)

    return sams
# {{/docs-fragment workflow}}

# You can now run the workflow using the command in the dropdown at the top of the page!

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(alignment_wf)
    print(run.url)
    run.wait()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/genomic_alignment/genomic_alignment.py*
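The listing replaces the v1 `@dynamic` workflow with `asyncio.gather` in `bowtie2_align_samples` and `alignment_wf`. The sketch below uses plain `asyncio` with no Flyte, and `align_one` is a hypothetical stand-in that sleeps instead of running Bowtie 2. It shows the two properties the fan-out relies on: the coroutines run concurrently, and `gather` returns results in the order the awaitables were passed, so each result lines up with its input sample even when samples finish out of order. It does not model how Flyte schedules the underlying task runs.

```python
import asyncio

completed: list[str] = []  # records the order in which the coroutines finish


async def align_one(sample: str, delay: float) -> str:
    # Hypothetical stand-in for bowtie2_align_paired_reads: sleep to simulate work.
    await asyncio.sleep(delay)
    completed.append(sample)
    return f"{sample}_bowtie2_aligned.sam"


async def align_all(jobs: dict[str, float]) -> list[str]:
    # One coroutine per sample, awaited together, as in bowtie2_align_samples.
    coros = [align_one(sample, delay) for sample, delay in jobs.items()]
    # gather runs them concurrently and returns results in argument order.
    return list(await asyncio.gather(*coros))


results = asyncio.run(align_all({"s1": 0.3, "s2": 0.1, "s3": 0.2}))
print(completed)  # ['s2', 's3', 's1']  (finish order follows the delays)
print(results)    # ['s1_bowtie2_aligned.sam', 's2_bowtie2_aligned.sam', 's3_bowtie2_aligned.sam']
```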

## Fetch assets

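The first task, `fetch_assets`, downloads the reference genome and paired-end reads from remote URLs and wraps them as a `Dir` (the reference) and `File` objects (the reads). It runs in `fetch_env`, which sets `cache="auto"`, so repeat runs with the same URLs reuse the cached outputs instead of downloading again.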

```
@fetch_env.task
async def fetch_assets(
    ref_url: str, read_urls: List[str]
) -> tuple[Reference, List[Reads]]:
    """
    Fetch assets from remote URLs.
    """
    # Download reference genome
    ref_dir = Path("/tmp/reference_genome")
    ref_dir.mkdir(exist_ok=True, parents=True)
    ref = fetch_file(ref_url, str(ref_dir))
    ref_obj = Reference(
        ref_name=ref.name,
        ref_dir=await Dir.from_local(str(ref_dir)),
    )

    # Download sequencing reads
    dl_loc = Path("/tmp/reads")
    dl_loc.mkdir(exist_ok=True, parents=True)

    samples: dict[str, Reads] = {}
    for url in read_urls:
        fp = fetch_file(url, str(dl_loc))
        sample = fp.stem.split("_")[0]

        if sample not in samples:
            samples[sample] = Reads(sample=sample)

        if ".fastq.gz" in fp.name or "fasta" in fp.name:
            # removesuffix drops exact suffixes; str.strip would drop a set of characters
            mate = fp.name.removesuffix(".fastq.gz").removesuffix(".filt").split("_")[-1]
            if "1" in mate:
                samples[sample].read1 = await File.from_local(str(fp))
            elif "2" in mate:
                samples[sample].read2 = await File.from_local(str(fp))

    return ref_obj, list(samples.values())
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/genomic_alignment/genomic_alignment.py*
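`fetch_assets` infers each file's sample name and mate (R1 or R2) from its file name. The Flyte-free sketch below isolates that naming rule, assuming names of the form `<sample>_<mate>[.filt].fastq.gz` as in the tutorial's test data; `parse_read_name` and `group_pairs` are hypothetical helpers. It uses `str.removesuffix` to drop the extensions, because `str.strip` takes a set of characters rather than a suffix and can also eat the start of a name, as the last line shows.

```python
from pathlib import Path
from urllib.parse import urlparse


def parse_read_name(fname: str) -> tuple[str, str]:
    """Split '<sample>_<mate>[.filt].fastq.gz' into (sample, mate)."""
    sample = Path(fname).stem.split("_")[0]
    # removesuffix drops one exact suffix; strip(".fastq.gz") would instead trim
    # any of the characters . f a s t q g z from BOTH ends of the string.
    mate = fname.removesuffix(".fastq.gz").removesuffix(".filt").split("_")[-1]
    return sample, mate


def group_pairs(urls: list[str]) -> dict[str, dict[str, str]]:
    """Group read URLs by sample, keyed by mate ('1' or '2')."""
    pairs: dict[str, dict[str, str]] = {}
    for url in urls:
        fname = Path(urlparse(url).path).name
        sample, mate = parse_read_name(fname)
        pairs.setdefault(sample, {})[mate] = url
    return pairs


base = "https://github.com/unionai-oss/unionbio/raw/main/tests/assets/sequences/raw/"
pairs = group_pairs([base + "ERR250683-tiny_1.fastq.gz", base + "ERR250683-tiny_2.fastq.gz"])
print(sorted(pairs["ERR250683-tiny"]))         # ['1', '2']
print("sample_1.fastq.gz".strip(".fastq.gz"))  # mple_1  (the leading "sa" is gone too)
```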

## Quality filtering with fastp

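`pyfastp` runs fastp on each read pair to filter out low-quality reads using fastp's default settings (deduplication is not enabled by default, and this task passes no extra options). It runs in `fastp_env`, which requests 2 GiB of memory (`memory="2Gi"`) so fastp can process larger read files.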
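Building the fastp argument list as a pure function makes it easy to inspect the exact command without having fastp installed. The sketch below is a hypothetical helper, `fastp_command`, that reproduces the paired-end invocation `pyfastp` uses: `-i`/`-I` name the input R1/R2 files and `-o`/`-O` the filtered outputs, which follow the `<sample>_1/_2.fastq.gz` naming from `Reads.get_read_fnames()`. It uses `PurePosixPath` so the printed paths look the same on any OS.

```python
from pathlib import PurePosixPath


def fastp_command(sample: str, r1: str, r2: str, out_dir: str) -> list[str]:
    """Hypothetical helper: build the paired-end fastp argv that pyfastp runs."""
    out = PurePosixPath(out_dir)
    # Output names follow Reads.get_read_fnames(): <sample>_1/_2.fastq.gz
    o1 = out / f"{sample}_1.fastq.gz"
    o2 = out / f"{sample}_2.fastq.gz"
    return [
        "fastp",
        "-i", r1,       # input read 1
        "-I", r2,       # input read 2
        "-o", str(o1),  # filtered read 1
        "-O", str(o2),  # filtered read 2
    ]


cmd = fastp_command("ERR250683-tiny", "raw_1.fastq.gz", "raw_2.fastq.gz", "/tmp/filtered")
print(" ".join(cmd))
# fastp -i raw_1.fastq.gz -I raw_2.fastq.gz -o /tmp/filtered/ERR250683-tiny_1.fastq.gz -O /tmp/filtered/ERR250683-tiny_2.fastq.gz
```

To execute it, pass the list to `subprocess.run(cmd, check=True)` as `pyfastp` does; `check=True` raises `subprocess.CalledProcessError` if fastp exits with a non-zero status. Passing a list rather than a shell string also avoids quoting problems with unusual paths.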

```
@fastp_env.task
async def pyfastp(rs: Reads) -> Reads:
    """
    Perform quality filtering and preprocessing using Fastp on a Reads object.

    Args:
        rs (Reads): A Reads object containing raw sequencing data to be processed.

    Returns:
        Reads: A Reads object representing the filtered and preprocessed data.
    """
    ldir = Path(tempfile.mkdtemp())
    samp = Reads(rs.sample)
    o1, o2 = samp.get_read_fnames()
    o1p = ldir / o1
    o2p = ldir / o2

    assert rs.read1 is not None and rs.read2 is not None
    r1 = await rs.read1.download()
    r2 = await rs.read2.download()

    cmd = [
        "fastp",
        "-i", str(r1),
        "-I", str(r2),
        "-o", str(o1p),
        "-O", str(o2p),
    ]
    subprocess.run(cmd, check=True)

    samp.read1 = await File.from_local(str(o1p))
    samp.read2 = await File.from_local(str(o2p))

    return samp
# {{/docs-fragment pyfastp}}

# Next, we define a task to generate Bowtie2 index files from a reference genome. As the index
# for a given tool and reference seldom changes, we cache this task.

# {{docs-fragment bowtie2_index}}
@index_env.task
async def bowtie2_index(ref: Reference) -> Reference:
    """
    Generate Bowtie2 index files from a reference genome.

    Args:
        ref (Reference): A Reference object representing the reference genome.

    Returns:
        Reference: A new Reference for the indexed directory, with index_name and indexed_with set.
    """
    ref_dir = await ref.ref_dir.download()
    idx_name = "bt2_idx"
    cmd = [
        "bowtie2-build",
        str(Path(str(ref_dir)) / ref.ref_name),
        str(Path(str(ref_dir)) / idx_name),
    ]
    subprocess.run(cmd, check=True)
    return Reference(
        ref.ref_name,
        await Dir.from_local(str(ref_dir)),
        idx_name,
        "bowtie2",
    )
# {{/docs-fragment bowtie2_index}}

# The next task performs paired-end alignment using Bowtie 2 on a single sample.

# {{docs-fragment bowtie2_align}}
@align_env.task
async def bowtie2_align_paired_reads(idx: Reference, fs: Reads) -> Alignment:
    """
    Perform paired-end alignment using Bowtie 2 on a filtered sample.

    Args:
        idx (Reference): A Reference object containing the Bowtie 2 index.
        fs (Reads): A filtered Reads object containing sample data to be aligned.

    Returns:
        Alignment: An Alignment object representing the alignment result.
    """
    assert idx.indexed_with == "bowtie2", "Reference index must be generated with bowtie2"
    assert idx.index_name is not None
    assert fs.read1 is not None and fs.read2 is not None

    ref_dir = await idx.ref_dir.download()
    r1 = await fs.read1.download()
    r2 = await fs.read2.download()

    ldir = Path(tempfile.mkdtemp())
    alignment = Alignment(fs.sample, "bowtie2", "sam")
    al = ldir / alignment.get_alignment_fname()

    cmd = [
        "bowtie2",
        "-x", str(Path(str(ref_dir)) / idx.index_name),
        "-1", str(r1),
        "-2", str(r2),
        "-S", str(al),
    ]
    subprocess.run(cmd, check=True)

    alignment.alignment = await File.from_local(str(al))
    return alignment
# {{/docs-fragment bowtie2_align}}

# In place of the v1 `@dynamic` workflow, we use a plain async task with `asyncio.gather`
# to run alignments for all samples in parallel.

@base_env.task
async def bowtie2_align_samples(
    idx: Reference, samples: List[Reads]
) -> List[Alignment]:
    """
    Process samples through bowtie2 in parallel.

    Args:
        idx (Reference): A Reference object containing the Bowtie 2 index.
        samples (List[Reads]): A list of Reads objects to be aligned.

    Returns:
        List[Alignment]: A list of Alignment objects representing the alignment results.
    """
    tasks = [bowtie2_align_paired_reads(idx=idx, fs=sample) for sample in samples]
    return list(await asyncio.gather(*tasks))

# ## End-to-End Workflow
#
# We tie everything together in a final task that fetches assets, filters them, generates
# an index, and aligns the samples. In place of the v1 `@workflow`, we use a top-level
# `@base_env.task`. Parallelism across samples is achieved with `asyncio.gather`.

# {{docs-fragment workflow}}
@base_env.task
async def alignment_wf() -> List[Alignment]:
    # Prepare raw samples from remote URLs
    ref, samples = await fetch_assets(
        ref_url="https://github.com/unionai-oss/unionbio/raw/main/tests/assets/references/GRCh38_short.fasta",
        read_urls=[
            "https://github.com/unionai-oss/unionbio/raw/main/tests/assets/sequences/raw/ERR250683-tiny_1.fastq.gz",
            "https://github.com/unionai-oss/unionbio/raw/main/tests/assets/sequences/raw/ERR250683-tiny_2.fastq.gz",
        ],
    )

    # Filter all samples in parallel
    filtered_samples = list(
        await asyncio.gather(*[pyfastp(rs=s) for s in samples])
    )

    # Generate a bowtie2 index or load it from cache
    bowtie2_idx = await bowtie2_index(ref=ref)

    # Generate alignments using bowtie2
    sams = await bowtie2_align_samples(idx=bowtie2_idx, samples=filtered_samples)

    return sams
# {{/docs-fragment workflow}}

# You can now run the workflow using the command in the dropdown at the top of the page!

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(alignment_wf)
    print(run.url)
    run.wait()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/genomic_alignment/genomic_alignment.py*
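`fetch_assets` recovers the sample name and mate number (1 or 2) from each downloaded file name. Note that Python's `str.strip` removes any of the given *characters* from both ends of a string rather than a literal suffix, so `str.removesuffix` (Python 3.9+) is the right tool for dropping a known extension. The sketch below isolates that parsing in a hypothetical helper, `parse_read_name`, which is not part of the tutorial. It assumes names of the form `<sample>_<mate>.fastq.gz`, optionally with a `.filt` marker before the extension, and splits on the last underscore (the tutorial instead takes the text before the first underscore as the sample name).

```python
from pathlib import Path


def parse_read_name(path: str) -> tuple[str, str]:
    """Split a paired-end read file name into (sample, mate).

    Hypothetical helper: assumes names like 'ERR250683-tiny_1.fastq.gz'
    or 'ERR250683-tiny_2.filt.fastq.gz'.
    """
    name = Path(path).name
    # removesuffix drops one exact trailing substring (if present);
    # str.strip would drop any of the listed characters from both ends.
    base = name.removesuffix(".fastq.gz").removesuffix(".filt")
    # Everything before the last underscore is the sample; the rest is the mate.
    sample, _, mate = base.rpartition("_")
    return sample, mate


if __name__ == "__main__":
    print(parse_read_name("/tmp/reads/ERR250683-tiny_1.fastq.gz"))  # ('ERR250683-tiny', '1')
    print(parse_read_name("ERR250683-tiny_2.filt.fastq.gz"))        # ('ERR250683-tiny', '2')
    # strip() treats its argument as a character set, not a suffix:
    print("data.fastq.gz".strip(".fastq.gz"))         # d
    print("data.fastq.gz".removesuffix(".fastq.gz"))  # data
```

The `strip` result shows why it only happens to work on the tutorial's file names: the mate digits are not in the stripped character set, but a sample name made of those characters would be eaten away.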

## Build the Bowtie 2 index

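The Bowtie 2 index for a given reference rarely changes, so this task runs in a cached environment (`index_env` sets `cache="auto"`), letting repeated runs with the same inputs reuse the index instead of rebuilding it.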

```
@index_env.task
async def bowtie2_index(ref: Reference) -> Reference:
    """
    Generate Bowtie2 index files from a reference genome.

    Args:
        ref (Reference): A Reference object representing the reference genome.

    Returns:
        Reference: A new Reference for the indexed directory, with index_name and indexed_with set.
    """
    ref_dir = await ref.ref_dir.download()
    idx_name = "bt2_idx"
    cmd = [
        "bowtie2-build",
        str(Path(str(ref_dir)) / ref.ref_name),
        str(Path(str(ref_dir)) / idx_name),
    ]
    subprocess.run(cmd, check=True)
    return Reference(
        ref.ref_name,
        await Dir.from_local(str(ref_dir)),
        idx_name,
        "bowtie2",
    )
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/genomic_alignment/genomic_alignment.py*
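`bowtie2-build <reference> <prefix>` writes a set of index files that share the given prefix (`bt2_idx` in this tutorial), and `bowtie2 -x <prefix>` reads them back at alignment time. For a "small" (32-bit) index these are six files ending in `.1.bt2`, `.2.bt2`, `.3.bt2`, `.4.bt2`, `.rev.1.bt2`, and `.rev.2.bt2`; very large genomes get a 64-bit index with the `.bt2l` extension instead. A hypothetical sanity check, not part of the tutorial, that could run before the index directory is uploaded might look like this. It assumes the small-index `.bt2` layout.

```python
import tempfile
from pathlib import Path

# File suffixes bowtie2-build produces for a "small" (32-bit) index.
BT2_SUFFIXES = (".1.bt2", ".2.bt2", ".3.bt2", ".4.bt2", ".rev.1.bt2", ".rev.2.bt2")


def missing_bt2_files(index_dir: str, index_name: str) -> list[str]:
    """Return the expected index file names that are absent from index_dir."""
    root = Path(index_dir)
    expected = [index_name + suffix for suffix in BT2_SUFFIXES]
    return [fname for fname in expected if not (root / fname).exists()]


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        # Simulate a partial build that only wrote the forward index files.
        for suffix in BT2_SUFFIXES[:4]:
            (Path(d) / f"bt2_idx{suffix}").touch()
        print(missing_bt2_files(d, "bt2_idx"))
        # ['bt2_idx.rev.1.bt2', 'bt2_idx.rev.2.bt2']
```

Inside `bowtie2_index`, a check like this could raise an error when the returned list is non-empty, so that an incomplete directory is never returned (and cached) as a valid index.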

## Align reads

Each filtered sample is aligned to the indexed reference with Bowtie 2 in paired-end mode, producing one SAM file per sample.

```
@align_env.task
async def bowtie2_align_paired_reads(idx: Reference, fs: Reads) -> Alignment:
    """
    Perform paired-end alignment using Bowtie 2 on a filtered sample.

    Args:
        idx (Reference): A Reference object containing the Bowtie 2 index.
        fs (Reads): A filtered Reads object containing sample data to be aligned.

    Returns:
        Alignment: An Alignment object representing the alignment result.
    """
    assert idx.indexed_with == "bowtie2", "Reference index must be generated with bowtie2"
    assert idx.index_name is not None
    assert fs.read1 is not None and fs.read2 is not None

    ref_dir = await idx.ref_dir.download()
    r1 = await fs.read1.download()
    r2 = await fs.read2.download()

    ldir = Path(tempfile.mkdtemp())
    alignment = Alignment(fs.sample, "bowtie2", "sam")
    al = ldir / alignment.get_alignment_fname()

    cmd = [
        "bowtie2",
        "-x", str(Path(str(ref_dir)) / idx.index_name),
        "-1", str(r1),
        "-2", str(r2),
        "-S", str(al),
    ]
    subprocess.run(cmd, check=True)

    alignment.alignment = await File.from_local(str(al))
    return alignment
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/genomic_alignment/genomic_alignment.py*
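Each alignment task returns a SAM file. SAM is a tab-separated text format: header lines start with `@`, and in each alignment record the second column is a bitwise FLAG in which `0x4` marks an unmapped read, `0x100` a secondary alignment, and `0x800` a supplementary alignment. A minimal, hypothetical helper for a quick look at an output file (not part of the tutorial) could count mapped and unmapped primary records as below; for real workloads, a dedicated tool such as `samtools flagstat` is the usual choice.

```python
from collections import Counter

UNMAPPED = 0x4
SECONDARY = 0x100
SUPPLEMENTARY = 0x800


def summarize_sam(lines) -> Counter:
    """Count primary mapped/unmapped records in an iterable of SAM text lines."""
    counts = Counter()
    for line in lines:
        if not line.strip() or line.startswith("@"):
            continue  # skip blank lines and header lines
        flag = int(line.split("\t")[1])  # column 2 is the FLAG field
        if flag & (SECONDARY | SUPPLEMENTARY):
            continue  # count each read once, via its primary record
        counts["unmapped" if flag & UNMAPPED else "mapped"] += 1
    return counts


if __name__ == "__main__":
    sam = [
        "@HD\tVN:1.6\tSO:unsorted",
        "r1\t99\tchr1\t100\t42\t4M\t=\t150\t54\tACGT\tIIII",
        "r1\t147\tchr1\t150\t42\t4M\t=\t100\t-54\tACGT\tIIII",
        "r2\t77\t*\t0\t0\t*\t*\t0\t0\tACGT\tIIII",
        "r2\t141\t*\t0\t0\t*\t*\t0\t0\tACGT\tIIII",
    ]
    print(summarize_sam(sam))  # Counter({'mapped': 2, 'unmapped': 2})
```

To use it on a downloaded alignment, pass an open file, for example `with open(path) as f: summarize_sam(f)`. Trailing newlines are harmless here because only the FLAG column is parsed.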

## Orchestrate the workflow

The top-level task fetches the assets, filters every sample in parallel, builds the index, and then aligns all samples in parallel. Parallelism across samples comes from calling tasks inside `asyncio.gather`, which replaces the separate `@dynamic` workflow used in Flyte v1.
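The fan-out itself is ordinary `asyncio`. The sketch below replaces the Flyte tasks with hypothetical plain coroutines (`filter_sample`, `build_index`, `align_sample`) that only sleep and return strings, so it runs without Flyte; in the tutorial, each call would instead be a task invocation. It shows that `asyncio.gather` runs the per-sample calls concurrently and returns results in argument order rather than completion order, which is why the outputs line up with the input samples.

```python
import asyncio


# Hypothetical stand-ins for the pyfastp, bowtie2_index, and
# bowtie2_align_paired_reads tasks; each just sleeps and returns a label.
async def filter_sample(sample: str) -> str:
    # Vary the delay so samples finish out of order.
    await asyncio.sleep(0.03 - 0.01 * (len(sample) % 3))
    return f"{sample}.filt"


async def build_index(ref: str) -> str:
    await asyncio.sleep(0.01)
    return f"{ref}:bt2_idx"


async def align_sample(index: str, sample: str) -> str:
    await asyncio.sleep(0.01)
    return f"{sample}->{index}"


async def pipeline(ref: str, samples: list[str]) -> list[str]:
    # Fan out one coroutine per sample and wait for all of them.
    filtered = await asyncio.gather(*(filter_sample(s) for s in samples))
    index = await build_index(ref)
    # gather returns results in argument order, not completion order.
    return list(await asyncio.gather(*(align_sample(index, s) for s in filtered)))


if __name__ == "__main__":
    print(asyncio.run(pipeline("ref.fa", ["a", "bb", "ccc"])))
    # ['a.filt->ref.fa:bt2_idx', 'bb.filt->ref.fa:bt2_idx', 'ccc.filt->ref.fa:bt2_idx']
```

Because building the index does not depend on the filtered reads, a variant could start both at once, for example with `filtered, index = await asyncio.gather(asyncio.gather(*filter_calls), build_index(ref))`; the tutorial runs them one after the other.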

```
# In place of the v1 `@dynamic` workflow, we use a plain async task with `asyncio.gather`
# to run alignments for all samples in parallel.

@base_env.task
async def bowtie2_align_samples(
    idx: Reference, samples: List[Reads]
) -> List[Alignment]:
    """
    Process samples through bowtie2 in parallel.

    Args:
        idx (Reference): A Reference object containing the Bowtie 2 index.
        samples (List[Reads]): A list of Reads objects to be aligned.

    Returns:
        List[Alignment]: A list of Alignment objects representing the alignment results.
    """
    tasks = [bowtie2_align_paired_reads(idx=idx, fs=sample) for sample in samples]
    return list(await asyncio.gather(*tasks))

# ## End-to-End Workflow
#
# We tie everything together in a final task that fetches assets, filters them, generates
# an index, and aligns the samples. In place of the v1 `@workflow`, we use a top-level
# `@base_env.task`. Parallelism across samples is achieved with `asyncio.gather`.

# {{docs-fragment workflow}}
@base_env.task
async def alignment_wf() -> List[Alignment]:
    # Prepare raw samples from remote URLs
    ref, samples = await fetch_assets(
        ref_url="https://github.com/unionai-oss/unionbio/raw/main/tests/assets/references/GRCh38_short.fasta",
        read_urls=[
            "https://github.com/unionai-oss/unionbio/raw/main/tests/assets/sequences/raw/ERR250683-tiny_1.fastq.gz",
            "https://github.com/unionai-oss/unionbio/raw/main/tests/assets/sequences/raw/ERR250683-tiny_2.fastq.gz",
        ],
    )

    # Filter all samples in parallel
    filtered_samples = list(
        await asyncio.gather(*[pyfastp(rs=s) for s in samples])
    )

    # Generate a bowtie2 index or load it from cache
    bowtie2_idx = await bowtie2_index(ref=ref)

    # Generate alignments using bowtie2
    sams = await bowtie2_align_samples(idx=bowtie2_idx, samples=filtered_samples)

    return sams
# {{/docs-fragment workflow}}

# Running this file directly submits `alignment_wf` and waits for it to finish.

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(alignment_wf)
    print(run.url)
    run.wait()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/genomic_alignment/genomic_alignment.py*
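The fan-out pattern is plain `asyncio`, so you can try it without Flyte. The sketch below uses two hypothetical coroutines, `filter_sample` and `align_sample`, as stand-ins for the `pyfastp` and `bowtie2_align_paired_reads` tasks. Each one sleeps instead of doing real work. The sketch shows two properties of `asyncio.gather`: the per-sample calls run concurrently, and results come back in input order, so each output lines up with the sample list passed in.

```python
import asyncio
import time


# Hypothetical stand-ins for the pyfastp and bowtie2_align_paired_reads tasks.
# Each sleeps to simulate work; in the tutorial these are Flyte tasks.
async def filter_sample(sample: str) -> str:
    await asyncio.sleep(0.1)
    return f"{sample}.filtered"


async def align_sample(sample: str) -> str:
    await asyncio.sleep(0.1)
    return f"{sample}.sam"


async def pipeline(samples: list[str]) -> list[str]:
    # Fan out: one filter call per sample, all awaited together.
    filtered = await asyncio.gather(*[filter_sample(s) for s in samples])
    # Fan out again for alignment; gather returns results in input order.
    return list(await asyncio.gather(*[align_sample(s) for s in filtered]))


if __name__ == "__main__":
    start = time.perf_counter()
    result = asyncio.run(pipeline(["ERR250683-tiny", "sampleB", "sampleC"]))
    print(result)
    # Each stage runs its calls concurrently, so this takes roughly 0.2 s rather than 0.6 s.
    print(f"{time.perf_counter() - start:.2f}s")
```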

## Run the workflow

This example needs no secrets or external API keys; it pulls public test data from GitHub.

From the [example directory](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/genomic_alignment), run it as a `uv` script:

```
cd v2/tutorials/genomic_alignment
uv run --script genomic_alignment.py
```

Or submit it with the Flyte CLI:

```
flyte run genomic_alignment.py alignment_wf
```

When the run completes, each returned `Alignment` points to a SAM file in blob storage that you can download from the run's outputs in the UI.
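After you download a SAM file from the run outputs, you can sanity-check it by counting how many reads aligned. This is a minimal standard-library sketch. It assumes the file is available locally; for the tutorial's sample, `Alignment.get_alignment_fname()` would name it `ERR250683-tiny_bowtie2_aligned.sam`. In SAM, header lines start with `@` and the FLAG is the second tab-separated column. Bit `0x4` marks an unmapped record. The sketch skips secondary (`0x100`) and supplementary (`0x800`) records so that each read mate is counted once.

```python
import io
from typing import Iterable, TextIO


def count_mapped(lines: Iterable[str]) -> tuple[int, int]:
    """Return (mapped, unmapped) primary-record counts from SAM text lines."""
    mapped = unmapped = 0
    for line in lines:
        if not line.strip() or line.startswith("@"):
            continue  # skip blank lines and header lines (@HD, @SQ, @PG, ...)
        flag = int(line.split("\t")[1])
        if flag & 0x900:
            continue  # secondary or supplementary alignment
        if flag & 0x4:
            unmapped += 1
        else:
            mapped += 1
    return mapped, unmapped


def summarize(handle: TextIO) -> str:
    mapped, unmapped = count_mapped(handle)
    total = mapped + unmapped
    rate = mapped / total if total else 0.0
    return f"{mapped}/{total} records mapped ({rate:.1%})"


if __name__ == "__main__":
    # Tiny inline SAM for illustration; replace with open("<your file>.sam").
    sample = io.StringIO(
        "@HD\tVN:1.0\tSO:unsorted\n"
        "r1\t99\tchr1\t100\t42\t4M\t=\t150\t54\tACGT\tIIII\n"
        "r1\t147\tchr1\t150\t42\t4M\t=\t100\t-54\tACGT\tIIII\n"
        "r2\t77\t*\t0\t0\t*\t*\t0\t0\tACGT\tIIII\n"
    )
    print(summarize(sample))
```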

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/biotech-healthcare/tumor-detection ===

# Brain tumor MRI classification

> [!NOTE]
> Code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/tumor_detection).

This tutorial builds a medical-imaging pipeline that classifies brain MRI scans into four categories — Glioma, Meningioma, No Tumor, and Pituitary — using a two-phase EfficientNet-B4 transfer-learning strategy. The pipeline downloads the dataset, trains on a GPU with fault-tolerant checkpointing, and renders training curves and a confusion matrix directly in the Union.ai UI.

The example is split into focused modules:

- `config.py` — container image, task environments, and the `TrainingConfig` hyperparameters.
- `dataset.py` — downloads the Hugging Face dataset, builds class-balanced data loaders.
- `model.py` / `training.py` — the Lightning module and the two-phase training loop.
- `utils.py` — plotting helpers for the report.
- `run.py` — the three Flyte tasks and the pipeline driver.

Flyte handles the production concerns:

- **Per-task resources**: CPU for download/reporting, a GPU for training.
- **`cache="auto"`** on dataset download and training, so reruns with the same data and config reuse the cached outputs instead of recomputing them.
- **`retries=3`** plus **Flyte checkpointing** on the training task so a preempted GPU job resumes from the last epoch.
- **Built-in reports** to visualize metrics without separate dashboard infrastructure.

## Define the container image

A single GPU-ready image is shared by all tasks. `with_source_folder` copies the local modules (`dataset.py`, `model.py`, etc.) into the image.

```
"""
Configuration for brain tumor MRI classification pipeline.

Defines task environments, resource requirements, and training hyperparameters.
"""

import pathlib

import flyte

# {{docs-fragment image}}
image = flyte.Image.from_debian_base(
    name="tumor_detection_gpu"
).with_pip_packages(
    "torch",
    "lightning",
    "torchvision",
    "timm",
    "pillow",
    "scikit-learn",
    "plotly",
    "numpy",
    "pandas",
    "torchmetrics",
    "datasets",
    "typing_extensions",
).with_source_folder(
    pathlib.Path(__file__).parent,
    copy_contents_only=True,
)
# {{/docs-fragment image}}

# {{docs-fragment envs}}
# Downloads raw MRI JPEG files — CPU only, no auth needed, result is cached
dataset_env = flyte.TaskEnvironment(
    name="tumor_dataset",
    image=image,
    resources=flyte.Resources(cpu=2, memory="4Gi", disk="8Gi"),
    cache="auto",
)

# GPU training — result is cached so re-running with the same data + config is free
training_env = flyte.TaskEnvironment(
    name="tumor_gpu_training",
    image=image,
    resources=flyte.Resources(
        cpu=8,
        memory="32Gi",
        gpu="T4:1",
        disk="100Gi",
    ),
    env_vars={
        "CUDA_VISIBLE_DEVICES": "0",
        "CUDA_LAUNCH_BLOCKING": "1",
        "TORCH_CUDA_MEMORY_FRACTION": "1.0",
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
    },
    cache="auto",
)

# Report generation — CPU only, reads training results and renders Union UI panels
report_env = flyte.TaskEnvironment(
    name="tumor_report",
    image=image,
    resources=flyte.Resources(cpu=2, memory="4Gi"),
)

# Pipeline driver — lightweight orchestrator that calls the three tasks above
pipeline_env = flyte.TaskEnvironment(
    name="tumor_pipeline",
    image=image,
    resources=flyte.Resources(cpu=2, memory="4Gi"),
    depends_on=[dataset_env, training_env, report_env],
)
# {{/docs-fragment envs}}

class TrainingConfig:
    """Unified training configuration for brain tumor MRI classification."""

    def __init__(
        self,
        image_size: int = 380,
        num_classes: int = 4,
        model_name: str = "efficientnet_b4",
        pretrained: bool = True,
        phase1_epochs: int = 8,
        phase1_lr: float = 1e-3,
        phase1_freeze_backbone: bool = True,
        phase2_epochs: int = 25,
        phase2_lr: float = 5e-5,
        batch_size: int = 16,
        num_workers: int = 0,
        val_split: float = 0.2,
        weight_decay: float = 1e-4,
        warmup_steps: int = 200,
        focal_gamma: float = 2.0,
        mixup_alpha: float = 0.0,
        log_interval: int = 50,
    ):
        self.image_size = image_size
        self.num_classes = num_classes

        self.model_name = model_name
        self.pretrained = pretrained

        self.phase1_epochs = phase1_epochs
        self.phase1_lr = phase1_lr
        self.phase1_freeze_backbone = phase1_freeze_backbone

        self.phase2_epochs = phase2_epochs
        self.phase2_lr = phase2_lr

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_split = val_split

        self.weight_decay = weight_decay
        self.warmup_steps = warmup_steps

        self.focal_gamma = focal_gamma
        self.mixup_alpha = mixup_alpha

        self.log_interval = log_interval

    def to_dict(self) -> dict:
        return self.__dict__
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/tumor_detection/config.py*

## Define the task environments

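Each stage declares the resources it needs. The lightweight `pipeline_env` hosts the driver task, and its `depends_on` list declares the other three environments so that the driver can call the tasks that live in them.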

```
"""
Configuration for brain tumor MRI classification pipeline.

Defines task environments, resource requirements, and training hyperparameters.
"""

# Downloads raw MRI JPEG files — CPU only, no auth needed, result is cached
dataset_env = flyte.TaskEnvironment(
    name="tumor_dataset",
    image=image,
    resources=flyte.Resources(cpu=2, memory="4Gi", disk="8Gi"),
    cache="auto",
)

# GPU training — result is cached so re-running with the same data + config is free
training_env = flyte.TaskEnvironment(
    name="tumor_gpu_training",
    image=image,
    resources=flyte.Resources(
        cpu=8,
        memory="32Gi",
        gpu="T4:1",
        disk="100Gi",
    ),
    env_vars={
        "CUDA_VISIBLE_DEVICES": "0",
        "CUDA_LAUNCH_BLOCKING": "1",
        "TORCH_CUDA_MEMORY_FRACTION": "1.0",
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
    },
    cache="auto",
)

# Report generation — CPU only, reads training results and renders Union UI panels
report_env = flyte.TaskEnvironment(
    name="tumor_report",
    image=image,
    resources=flyte.Resources(cpu=2, memory="4Gi"),
)

# Pipeline driver — lightweight orchestrator that calls the three tasks above
pipeline_env = flyte.TaskEnvironment(
    name="tumor_pipeline",
    image=image,
    resources=flyte.Resources(cpu=2, memory="4Gi"),
    depends_on=[dataset_env, training_env, report_env],
)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/tumor_detection/config.py*

## Configure training

Hyperparameters are gathered in a single `TrainingConfig`, serialized to JSON, and passed into the training task so the exact configuration is captured alongside the run.
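Passing the configuration as a JSON string keeps the task signature to a simple type. The round trip can be sketched with `MiniConfig`, a hypothetical, trimmed-down class that follows the same pattern as `TrainingConfig`: keyword arguments go in and `__dict__` comes out. Returning `self.__dict__` directly hands callers the live attribute dictionary, so the sketch returns a copy instead.

```python
import json


class MiniConfig:
    """Hypothetical, reduced version of TrainingConfig for illustration."""

    def __init__(self, phase1_epochs: int = 8, phase2_lr: float = 5e-5, focal_gamma: float = 2.0):
        self.phase1_epochs = phase1_epochs
        self.phase2_lr = phase2_lr
        self.focal_gamma = focal_gamma

    def to_dict(self) -> dict:
        # Return a copy so callers cannot mutate the config by accident.
        return dict(self.__dict__)


# Driver side: serialize the config into a plain string task input.
config_json = json.dumps(MiniConfig(focal_gamma=3.0).to_dict())

# Task side: rebuild the object from the JSON string.
restored = MiniConfig(**json.loads(config_json))
print(config_json)
print(restored.focal_gamma)
```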

```
"""
Flyte/Union pipeline for brain tumor MRI classification.

Three-task pipeline:
1. load_dataset  — download Brain Tumor MRI from Hugging Face, cache as Dir (CPU)
2. train_model   — two-phase EfficientNet-B4 training with focal loss (GPU)
3. create_report — render training curves and confusion matrix in the Union UI (CPU)
"""

import json

import flyte
from flyte.io import Dir

from config import TrainingConfig, dataset_env, pipeline_env, report_env, training_env
from dataset import download_tumor_dataset

# {{docs-fragment config}}
TRAINING_CONFIG = TrainingConfig(
    phase1_epochs=8,
    phase2_epochs=25,
    phase1_lr=1e-3,
    phase2_lr=5e-5,
    batch_size=16,
    num_workers=0,
    log_interval=50,
    mixup_alpha=0.0,
    image_size=380,
    focal_gamma=3.0,
)
# {{/docs-fragment config}}

# {{docs-fragment load_dataset}}
@dataset_env.task
async def load_dataset() -> Dir:
    """
    Download raw Brain Tumor MRI JPEG files from Hugging Face and cache as flyte.io.Dir.
    Runs once — result is reused on subsequent pipeline runs (cache="auto").
    """
    return await download_tumor_dataset()
# {{/docs-fragment load_dataset}}

# {{docs-fragment train_model}}
@training_env.task(retries=3)
async def train_model(dataset_dir: Dir, config_json: str) -> Dir:
    """
    Download the raw dataset Dir, run two-phase training,
    and return training metrics and final predictions as a Dir for the report task.
    """
    from pathlib import Path

    local_dir = Path("/tmp/tumor_local")
    local_dir.mkdir(parents=True, exist_ok=True)
    await dataset_dir.download(local_path=str(local_dir))

    from training import train_tumor_classifier
    config = TrainingConfig(**json.loads(config_json))
    result = train_tumor_classifier(config=config, dataset_path=str(local_dir))

    output_dir = Path("/tmp/training_results")
    output_dir.mkdir(parents=True, exist_ok=True)
    (output_dir / "metrics.json").write_text(json.dumps(result["metrics"]))
    (output_dir / "predictions.json").write_text(json.dumps({
        "preds": result["final_preds"],
        "targets": result["final_targets"],
    }))

    return await Dir.from_local(str(output_dir))
# {{/docs-fragment train_model}}

# {{docs-fragment create_report}}
@report_env.task(report=True)
async def create_report(results_dir: Dir) -> None:
    """
    Download training metrics and render loss/accuracy curves, confusion matrix,
    and per-class F1 chart in the Union UI report panel.
    """
    import numpy as np
    from pathlib import Path

    from utils import create_confusion_matrix_plot, create_metrics_plots, create_per_class_f1_plot

    local_dir = Path("/tmp/tumor_report")
    local_dir.mkdir(parents=True, exist_ok=True)
    await results_dir.download(local_path=str(local_dir))

    matches = list(local_dir.glob("**/metrics.json"))
    if not matches:
        raise RuntimeError(f"metrics.json not found under {local_dir}")
    local_path = matches[0].parent

    history = json.loads((local_path / "metrics.json").read_text())
    predictions = json.loads((local_path / "predictions.json").read_text())

    preds = np.array(predictions["preds"])
    targets = np.array(predictions["targets"])

    loss_fig, acc_fig = create_metrics_plots(history)
    cm_fig = create_confusion_matrix_plot(preds, targets)
    f1_fig = create_per_class_f1_plot(preds, targets)

    combined_html = (
        acc_fig.to_html(include_plotlyjs=True, full_html=False)
        + loss_fig.to_html(include_plotlyjs=False, full_html=False)
        + cm_fig.to_html(include_plotlyjs=False, full_html=False)
        + f1_fig.to_html(include_plotlyjs=False, full_html=False)
    )
    flyte.report.log(combined_html, do_flush=True)
# {{/docs-fragment create_report}}

# {{docs-fragment pipeline}}
@pipeline_env.task
async def tumor_detection_pipeline() -> None:
    """Orchestrate dataset loading, GPU training, and report generation."""
    dataset_dir = await load_dataset()
    results_dir = await train_model(
        dataset_dir=dataset_dir,
        config_json=json.dumps(TRAINING_CONFIG.to_dict()),
    )
    await create_report(results_dir=results_dir)
# {{/docs-fragment pipeline}}

if __name__ == "__main__":
    import pathlib
    flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
    run = flyte.with_runcontext().run(tumor_detection_pipeline)
    print("\n✓ Pipeline submitted!")
    print(f"Run URL: {run.url}")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/tumor_detection/run.py*
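The pipeline docstring says training uses focal loss, and `TRAINING_CONFIG` raises `focal_gamma` from its default of 2.0 to 3.0. In its common form (omitting the optional per-class `alpha` weight), focal loss for a sample whose true class receives probability `p_t` is `-(1 - p_t)^gamma * log(p_t)`. A larger `gamma` therefore shrinks the loss from confidently correct predictions more than from hard ones. The standard-library sketch below implements that formula. It is not the tutorial's `training.py` code, which is not shown here.

```python
import math


def focal_loss(p_true: float, gamma: float) -> float:
    """Focal loss for one sample, given the probability assigned to its true class."""
    # gamma = 0 reduces this to ordinary cross-entropy, -log(p_true).
    return -((1.0 - p_true) ** gamma) * math.log(p_true)


if __name__ == "__main__":
    # Compare a hard example (p=0.3) with an easy one (p=0.9) at several gammas.
    for p in (0.3, 0.9):
        print(p, [round(focal_loss(p, g), 4) for g in (0.0, 2.0, 3.0)])
```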

## Load the dataset

The first task downloads the public [Brain Tumor MRI dataset](https://huggingface.co/datasets/AIOmarRehan/Brain_Tumor_MRI_Dataset) from Hugging Face (no auth required) and stores it as a `flyte.io.Dir`. It's cached, so subsequent runs reuse it.

```
"""
Flyte/Union pipeline for brain tumor MRI classification.

Three-task pipeline:
1. load_dataset  — download Brain Tumor MRI from Hugging Face, cache as Dir (CPU)
2. train_model   — two-phase EfficientNet-B4 training with focal loss (GPU)
3. create_report — render training curves and confusion matrix in the Union UI (CPU)
"""

@dataset_env.task
async def load_dataset() -> Dir:
    """
    Download raw Brain Tumor MRI JPEG files from Hugging Face and cache as flyte.io.Dir.
    Runs once — result is reused on subsequent pipeline runs (cache="auto").
    """
    return await download_tumor_dataset()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/tumor_detection/run.py*
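The module list says `dataset.py` builds class-balanced data loaders, but that file is not shown here. One common approach, assumed here rather than taken from the tutorial, is to weight each sample by the inverse of its class frequency and draw with replacement. Weights like these can be passed to PyTorch's `torch.utils.data.WeightedRandomSampler`. The standard-library sketch below uses made-up labels with the tutorial's four class names.

```python
import random
from collections import Counter


def inverse_frequency_weights(labels: list[str]) -> list[float]:
    """Weight each sample by 1 / (size of its class), so every class has equal total weight."""
    counts = Counter(labels)
    return [1.0 / counts[label] for label in labels]


def balanced_draw(labels: list[str], k: int, seed: int = 0) -> Counter:
    weights = inverse_frequency_weights(labels)
    rng = random.Random(seed)
    # Sample indices with replacement, proportionally to the weights.
    picks = rng.choices(range(len(labels)), weights=weights, k=k)
    return Counter(labels[i] for i in picks)


if __name__ == "__main__":
    # Made-up, imbalanced label list using the tutorial's four classes.
    labels = ["glioma"] * 60 + ["meningioma"] * 25 + ["no_tumor"] * 10 + ["pituitary"] * 5
    # Each class should be drawn roughly 1000 times out of 4000.
    print(balanced_draw(labels, k=4000))
```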

## Train the model

The training task downloads the dataset `Dir`, runs two-phase training (frozen backbone, then full fine-tuning), and writes metrics and predictions to an output `Dir`. It sets `retries=3`, so if the GPU node is preempted, Flyte retries the task (up to three times) instead of failing the run.
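The `TrainingConfig` defaults imply a simple schedule. Phase 1 runs 8 epochs at learning rate `1e-3` with the backbone frozen. Phase 2 runs 25 epochs at `5e-5` with every layer trainable. A small sketch that expands those values into per-epoch settings makes the schedule explicit. The `EpochPlan` type and `plan_epochs` helper are hypothetical. The config also carries `warmup_steps` and `weight_decay`, which this sketch ignores.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class EpochPlan:
    epoch: int
    phase: int
    lr: float
    freeze_backbone: bool


def plan_epochs(phase1_epochs: int = 8, phase1_lr: float = 1e-3,
                phase1_freeze_backbone: bool = True,
                phase2_epochs: int = 25, phase2_lr: float = 5e-5) -> list[EpochPlan]:
    """Expand the two-phase hyperparameters into one entry per epoch."""
    # Phase 1 trains the new head on top of a (by default) frozen backbone.
    plan = [EpochPlan(e, 1, phase1_lr, phase1_freeze_backbone) for e in range(phase1_epochs)]
    # Phase 2 fine-tunes the whole network at a much lower learning rate.
    plan += [EpochPlan(phase1_epochs + e, 2, phase2_lr, False) for e in range(phase2_epochs)]
    return plan


if __name__ == "__main__":
    plan = plan_epochs()
    print(len(plan), plan[0], plan[8])
```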

```
"""
Flyte/Union pipeline for brain tumor MRI classification.

Three-task pipeline:
1. load_dataset  — download Brain Tumor MRI from Hugging Face, cache as Dir (CPU)
2. train_model   — two-phase EfficientNet-B4 training with focal loss (GPU)
3. create_report — render training curves and confusion matrix in the Union UI (CPU)
"""

# {{docs-fragment train_model}}
@training_env.task(retries=3)
async def train_model(dataset_dir: Dir, config_json: str) -> Dir:
    """
    Download the raw dataset Dir, run two-phase training,
    and return training metrics and final predictions as a Dir for the report task.
    """
    from pathlib import Path

    local_dir = Path("/tmp/tumor_local")
    local_dir.mkdir(parents=True, exist_ok=True)
    await dataset_dir.download(local_path=str(local_dir))

    from training import train_tumor_classifier
    config = TrainingConfig(**json.loads(config_json))
    result = train_tumor_classifier(config=config, dataset_path=str(local_dir))

    output_dir = Path("/tmp/training_results")
    output_dir.mkdir(parents=True, exist_ok=True)
    (output_dir / "metrics.json").write_text(json.dumps(result["metrics"]))
    (output_dir / "predictions.json").write_text(json.dumps({
        "preds": result["final_preds"],
        "targets": result["final_targets"],
    }))

    return await Dir.from_local(str(output_dir))
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/tumor_detection/run.py*
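The learning-rate schedule behind the two phases can be written as a function of the global step. This is a stdlib-only sketch, not the tutorial's code. The epoch counts and base rates come from `TRAINING_CONFIG`. The 10x lower backbone rate and `eta_min=1e-6` come from `PhaseChangeCallback` in `training.py`, shown in the next section. It assumes Phase 1 runs at a constant `phase1_lr` (the LightningModule's own Phase 1 scheduler and warmup live in `model.py`, which isn't shown), that the scheduler steps once per optimizer step, and that there is no gradient accumulation. `lr_at` and `cosine_lr` are hypothetical helpers.

```python
import math
from typing import Optional, Tuple

# From TRAINING_CONFIG in run.py and PhaseChangeCallback in training.py.
PHASE1_EPOCHS = 8
PHASE2_EPOCHS = 25
PHASE1_LR = 1e-3
PHASE2_LR = 5e-5
BACKBONE_LR_FACTOR = 0.1  # backbone trains 10x slower than the head in Phase 2
ETA_MIN = 1e-6            # floor of the Phase 2 cosine schedule


def cosine_lr(base_lr: float, t: int, t_max: int, eta_min: float = ETA_MIN) -> float:
    """Closed-form cosine annealing from base_lr (t = 0) down to eta_min (t = t_max)."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * t / t_max)) / 2


def lr_at(step: int, steps_per_epoch: int) -> Tuple[float, Optional[float]]:
    """Return (head_lr, backbone_lr) at a global step; backbone_lr is None while frozen."""
    phase2_start = PHASE1_EPOCHS * steps_per_epoch
    if step < phase2_start:
        # Phase 1 (assumed constant here): backbone frozen, only the head trains.
        return PHASE1_LR, None
    # Phase 2: a fresh cosine schedule over the remaining steps, applied to every group.
    t_max = max(1, PHASE2_EPOCHS * steps_per_epoch)
    t = step - phase2_start
    return (
        cosine_lr(PHASE2_LR, t, t_max),
        cosine_lr(PHASE2_LR * BACKBONE_LR_FACTOR, t, t_max),
    )


if __name__ == "__main__":
    steps_per_epoch = 100  # hypothetical; the real value depends on dataset size and batch_size
    for epoch in (0, 7, 8, 20, 32):
        head, backbone = lr_at(epoch * steps_per_epoch, steps_per_epoch)
        bb = "frozen" if backbone is None else f"{backbone:.2e}"
        print(f"epoch {epoch:2d}: head={head:.2e} backbone={bb}")
```

Restarting the cosine schedule at the phase boundary means Phase 2 begins at its full `phase2_lr` rather than wherever a single 33-epoch schedule would have decayed to by epoch 8.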

### Resumable checkpointing

Retries are only cheap if a restarted task does not begin again at epoch 0. The training loop therefore mirrors its Lightning checkpoint directory to a `flyte.Checkpoint` at the end of every training epoch, and on restart it resumes from the latest checkpoint.

```
"""
Training pipeline for brain tumor MRI classification.

Implements two-phase training:
- Phase 1: Frozen backbone (feature extractor), train classification head
- Phase 2: Fine-tune full model with differential LRs + cosine annealing
"""

from config import TrainingConfig
from dataset import compute_class_weights, create_data_loaders
from model import TumorClassifierLightningModule
from utils import get_model_size, get_trainable_params

def train_tumor_classifier(
    config: TrainingConfig,
    dataset_path: str,
) -> dict:
    """
    Run two-phase training on the downloaded dataset and return metrics + final predictions.

    dataset_path: local directory where the flyte.io.Dir was downloaded by the training task.
    """
    import pathlib

    import flyte
    import lightning as L
    import torch
    from lightning.pytorch.callbacks import ModelCheckpoint
    from typing_extensions import override

    class FlyteLightningCheckpointCallback(ModelCheckpoint):
        """Mirrors the checkpoint directory to Flyte after every epoch so retries can resume."""

        def __init__(self, flyte_checkpoint: flyte.Checkpoint, *, dirpath: str, **kwargs):
            super().__init__(dirpath=dirpath, **kwargs)
            self._flyte_checkpoint = flyte_checkpoint

        @override
        def on_train_epoch_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
            super().on_train_epoch_end(trainer, pl_module)
            if self.dirpath:
                self._flyte_checkpoint.save_sync(pathlib.Path(self.dirpath))

    class MetricsLoggerCallback(L.Callback):
        def __init__(self, phase1_epochs: int):
            super().__init__()
            self.phase1_epochs = phase1_epochs
            self.history = []

        def on_validation_epoch_end(self, trainer, _pl_module):
            epoch = trainer.current_epoch
            metrics = trainer.callback_metrics
            self.history.append({
                "epoch": epoch,
                "phase": 1 if epoch < self.phase1_epochs else 2,
                "train_loss": float(metrics.get("train/loss_epoch", 0)),
                "val_loss": float(metrics.get("val/loss", 0)),
                "val_acc": float(metrics.get("val/acc", 0)),
                "macro_f1": float(metrics.get("val/macro_f1", 0)),
            })

    class PhaseChangeCallback(L.Callback):
        def __init__(self, phase1_epochs: int, phase2_lr: float):
            super().__init__()
            self.phase1_epochs = phase1_epochs
            self.phase2_lr = phase2_lr
            self.phase_changed = False

        def on_train_epoch_end(self, trainer, pl_module):
            if not self.phase_changed and (trainer.current_epoch + 1) == self.phase1_epochs:
                print("\n" + "=" * 80)
                print("TRANSITIONING TO PHASE 2: UNFREEZING BACKBONE AND ADJUSTING LR")
                print("=" * 80 + "\n")

                pl_module.model.unfreeze_backbone()

                for param_group in trainer.optimizers[0].param_groups:
                    param_group["lr"] = self.phase2_lr

                # Add backbone params to optimizer with 10x lower LR.
                # Backbone was excluded at init because it was frozen.
                backbone_lr = self.phase2_lr * 0.1
                backbone_decay, backbone_no_decay = [], []
                for param in pl_module.model.backbone.parameters():
                    if param.ndim >= 2:
                        backbone_decay.append(param)
                    else:
                        backbone_no_decay.append(param)
                optimizer = trainer.optimizers[0]
                optimizer.add_param_group({"params": backbone_decay, "lr": backbone_lr, "weight_decay": pl_module.weight_decay})
                optimizer.add_param_group({"params": backbone_no_decay, "lr": backbone_lr, "weight_decay": 0.0})

                # Fresh cosine schedule over remaining Phase 2 steps to avoid
                # the Phase 1 schedule arriving near-zero before Phase 2 begins.
                steps_remaining = trainer.estimated_stepping_batches - trainer.global_step
                new_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                    trainer.optimizers[0],
                    T_max=max(1, steps_remaining),
                    eta_min=1e-6,
                )
                for lr_scheduler_config in trainer.lr_scheduler_configs:
                    lr_scheduler_config.scheduler = new_scheduler

                print(f"Phase 2 started: lr={self.phase2_lr}")
                print(f"Total parameters: {get_model_size(pl_module.model):,}")
                print(f"Trainable parameters: {get_trainable_params(pl_module.model):,}")
                self.phase_changed = True

    print("\n" + "=" * 80)
    print("BRAIN TUMOR MRI CLASSIFICATION WITH EFFICIENTNET-B4")
    print("=" * 80)
    print(f"Config: {config.to_dict()}\n")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

    print("\nLoading MRI images...")
    train_loader, val_loader = create_data_loaders(
        dataset_path=dataset_path,
        image_size=config.image_size,
        batch_size=config.batch_size,
        num_workers=config.num_workers,
        val_split=config.val_split,
    )
    print(f"Data loaders created: {len(train_loader)} train batches, {len(val_loader)} val batches")

    print("\nComputing class weights for focal loss...")
    class_weights = compute_class_weights(dataset_path)
    print(f"Class weights: {class_weights.tolist()}")

    # Per-class gamma: Meningioma gets 7.0, all others 3.0.
    # CLASS_NAMES alphabetical order: Glioma=0, Meningioma=1, No Tumor=2, Pituitary=3
    gamma_per_class = torch.tensor([3.0, 7.0, 3.0, 3.0])

    print("\nInitializing model...")
    model = TumorClassifierLightningModule(
        num_classes=config.num_classes,
        model_name=config.model_name,
        pretrained=config.pretrained,
        learning_rate=config.phase1_lr,
        freeze_backbone=config.phase1_freeze_backbone,
        weight_decay=config.weight_decay,
        warmup_steps=config.warmup_steps,
        max_epochs=config.phase1_epochs + config.phase2_epochs,
        focal_gamma=config.focal_gamma,
        mixup_alpha=config.mixup_alpha,
        class_weights=class_weights,
        gamma_per_class=gamma_per_class,
    )

    print(f"Model: {config.model_name}")
    print(f"Total parameters: {get_model_size(model.model):,}")
    print(f"Trainable parameters: {get_trainable_params(model.model):,}")

    from pathlib import Path
    checkpoint_dir = Path("/tmp/tumor_checkpoints")
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # --- Flyte checkpoint: resume from previous attempt if one exists ---
    resume_ckpt: str | None = None
    ctx = flyte.ctx()
    flyte_checkpoint = getattr(ctx, "checkpoint", None) if ctx else None

    if flyte_checkpoint:
        prev_path = flyte_checkpoint.load_sync()
        if prev_path:
            last = flyte.latest_checkpoint(prev_path)
            if last:
                ck = torch.load(str(last), map_location="cpu", weights_only=False)
                epoch_start = int(ck.get("epoch", 0))
                resume_ckpt = str(last)
                print(f"Resuming from epoch {epoch_start}, checkpoint: {last}")
    # --------------------------------------------------------------------

    metrics_cb = MetricsLoggerCallback(phase1_epochs=config.phase1_epochs)

    resume_callback = (
        FlyteLightningCheckpointCallback(
            flyte_checkpoint,
            dirpath=str(checkpoint_dir),
            filename="last",
            save_last=True,
            save_top_k=1,
        )
        if flyte_checkpoint else
        ModelCheckpoint(
            dirpath=str(checkpoint_dir),
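            # Reference the logged key "val/acc"; a key that was never logged (e.g. val_acc)
            # would always render as 0 in the file name.
            filename="best-{epoch:03d}-{val/acc:.3f}",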
            monitor="val/acc",
            mode="max",
            save_top_k=3,
            verbose=True,
            auto_insert_metric_name=False,
        )
    )

    callbacks = [
        resume_callback,
        metrics_cb,
        PhaseChangeCallback(
            phase1_epochs=config.phase1_epochs,
            phase2_lr=config.phase2_lr,
        ),
    ]

    trainer = L.Trainer(
        max_epochs=config.phase1_epochs + config.phase2_epochs,
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices=1,
        precision="16-mixed",
        callbacks=callbacks,
        enable_progress_bar=True,
        enable_model_summary=True,
        log_every_n_steps=config.log_interval,
        gradient_clip_val=1.0,
    )

    trainer.fit(model, train_loader, val_loader, ckpt_path=resume_ckpt)

    best_checkpoint = trainer.checkpoint_callback.best_model_path
    print(f"\n✓ Training complete!")
    print(f"Best checkpoint: {best_checkpoint}")

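    # Final inference with TTA (test-time augmentation): average the logits over the
    # original image, horizontal and vertical flips, and 90°/270° rotations. This often
    # improves accuracy slightly, at the cost of five forward passes per batch.
    # Note: it uses the weights in memory at the end of training (the last epoch),
    # not the best checkpoint.
    print("\nRunning final inference with TTA for confusion matrix...")
    import torchvision.transforms.functional as TF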
    model.eval()
    model.to(device)
    all_preds, all_targets = [], []
    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            aug_logits = [
                model.model(images),
                model.model(TF.hflip(images)),
                model.model(TF.vflip(images)),
                model.model(torch.rot90(images, k=1, dims=[2, 3])),
                model.model(torch.rot90(images, k=3, dims=[2, 3])),
            ]
            avg_logits = torch.stack(aug_logits).mean(dim=0)
            all_preds.append(avg_logits.argmax(dim=1).cpu())
            all_targets.append(labels.cpu())
    final_preds = torch.cat(all_preds).numpy()
    final_targets = torch.cat(all_targets).numpy()

    return {
        "best_checkpoint": best_checkpoint,
        "metrics": metrics_cb.history,
        "final_preds": final_preds.tolist(),
        "final_targets": final_targets.tolist(),
    }
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/tumor_detection/training.py*
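The per-epoch mirroring described at the start of this subsection can be simulated with the standard library to see why retries lose little work. This is a sketch, not Flyte's implementation. A local directory plays the task's scratch disk, and a second directory stands in for the storage behind `flyte.Checkpoint`; we assume `save_sync` behaves roughly like copying the directory there. `Preempted`, `train`, and `run_with_retries` are hypothetical names.

```python
import json
import shutil
import tempfile
from pathlib import Path
from typing import List, Optional, Tuple


class Preempted(Exception):
    """Stands in for a node preemption that kills the task partway through a run."""


def train(local_dir: Path, remote_dir: Path, max_epochs: int,
          fail_at: Optional[int], ran: List[int]) -> None:
    """Train for max_epochs, mirroring local_dir to remote_dir after every epoch."""
    local_dir.mkdir(parents=True, exist_ok=True)
    start = 0
    state = remote_dir / "last.json"
    if state.exists():
        # A previous attempt left a checkpoint: continue after its last finished epoch.
        start = json.loads(state.read_text())["epoch"] + 1
    for epoch in range(start, max_epochs):
        if epoch == fail_at:
            raise Preempted(f"node lost during epoch {epoch}")
        ran.append(epoch)  # placeholder for one epoch of real training
        (local_dir / "last.json").write_text(json.dumps({"epoch": epoch}))
        # Mirror the whole local checkpoint dir to durable storage (cf. save_sync).
        shutil.copytree(local_dir, remote_dir, dirs_exist_ok=True)


def run_with_retries(max_epochs: int, fail_at: int, retries: int = 3) -> Tuple[int, List[int]]:
    """Rerun train() after a preemption; each attempt gets a fresh local disk."""
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        ran: List[int] = []
        # Initial attempt plus `retries` reruns (this simulation's convention).
        for attempt in range(retries + 1):
            try:
                # Only the first attempt is preempted in this simulation.
                train(root / f"local{attempt}", root / "remote", max_epochs,
                      fail_at if attempt == 0 else None, ran)
                return attempt, ran
            except Preempted:
                continue
    raise RuntimeError("out of retries")


if __name__ == "__main__":
    attempt, epochs = run_with_retries(max_epochs=10, fail_at=6)
    print(f"finished on attempt {attempt}; epochs run: {epochs}")
```

Here the first attempt finishes epochs 0 to 5 and the retry picks up at epoch 6, so no completed epoch is repeated. A real preemption usually lands mid-epoch, in which case only that partially finished epoch is redone.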

On startup, the training loop checks for a checkpoint saved by a previous attempt and, if it finds one, passes it to `trainer.fit` as `ckpt_path` so training resumes from there:

```
    # --- Flyte checkpoint: resume from previous attempt if one exists ---
    resume_ckpt: str | None = None
    ctx = flyte.ctx()
    flyte_checkpoint = getattr(ctx, "checkpoint", None) if ctx else None

    if flyte_checkpoint:
        prev_path = flyte_checkpoint.load_sync()
        if prev_path:
            last = flyte.latest_checkpoint(prev_path)
            if last:
                ck = torch.load(str(last), map_location="cpu", weights_only=False)
                epoch_start = int(ck.get("epoch", 0))
                resume_ckpt = str(last)
                print(f"Resuming from epoch {epoch_start}, checkpoint: {last}")
    # --------------------------------------------------------------------
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/tumor_detection/training.py*

## Generate the report

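The reporting task reads the metrics and predictions, then uses Plotly to render accuracy and loss curves, a confusion matrix, and a per-class F1 chart. Because the task is declared with `report=True`, the HTML passed to `flyte.report.log` appears in the run's report panel. Only the first figure embeds the Plotly JavaScript bundle (`include_plotlyjs=True`), so the library is loaded once for all four charts.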

```
@report_env.task(report=True)
async def create_report(results_dir: Dir) -> None:
    """
    Download training metrics and render loss/accuracy curves, confusion matrix,
    and per-class F1 chart in the Union UI report panel.
    """
    import numpy as np
    from pathlib import Path

    from utils import create_confusion_matrix_plot, create_metrics_plots, create_per_class_f1_plot

    local_dir = Path("/tmp/tumor_report")
    local_dir.mkdir(parents=True, exist_ok=True)
    await results_dir.download(local_path=str(local_dir))

    matches = list(local_dir.glob("**/metrics.json"))
    if not matches:
        raise RuntimeError(f"metrics.json not found under {local_dir}")
    local_path = matches[0].parent

    history = json.loads((local_path / "metrics.json").read_text())
    predictions = json.loads((local_path / "predictions.json").read_text())

    preds = np.array(predictions["preds"])
    targets = np.array(predictions["targets"])

    loss_fig, acc_fig = create_metrics_plots(history)
    cm_fig = create_confusion_matrix_plot(preds, targets)
    f1_fig = create_per_class_f1_plot(preds, targets)

    combined_html = (
        acc_fig.to_html(include_plotlyjs=True, full_html=False)
        + loss_fig.to_html(include_plotlyjs=False, full_html=False)
        + cm_fig.to_html(include_plotlyjs=False, full_html=False)
        + f1_fig.to_html(include_plotlyjs=False, full_html=False)
    )
    flyte.report.log(combined_html, do_flush=True)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/tumor_detection/run.py*
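`create_confusion_matrix_plot` and `create_per_class_f1_plot` come from `utils.py`, which this page does not show. As a rough illustration of the numbers such charts are built from, here is a stdlib-only sketch. It assumes rows are true classes and columns are predicted classes, and uses the class order from the `CLASS_NAMES` comment in `training.py`. `confusion_matrix` and `per_class_f1` are hypothetical helpers, not the tutorial's code.

```python
from typing import List, Sequence

# Class order from the CLASS_NAMES comment in training.py.
CLASS_NAMES = ["Glioma", "Meningioma", "No Tumor", "Pituitary"]


def confusion_matrix(preds: Sequence[int], targets: Sequence[int], num_classes: int) -> List[List[int]]:
    """cm[t][p] counts samples whose true class is t and predicted class is p."""
    cm = [[0] * num_classes for _ in range(num_classes)]
    for p, t in zip(preds, targets):
        cm[t][p] += 1
    return cm


def per_class_f1(cm: List[List[int]]) -> List[float]:
    """F1 = 2TP / (2TP + FP + FN) per class, which equals 2PR / (P + R)."""
    n = len(cm)
    scores = []
    for c in range(n):
        tp = cm[c][c]
        fp = sum(cm[r][c] for r in range(n)) - tp  # predicted c, but truly another class
        fn = sum(cm[c]) - tp                         # truly c, but predicted as another class
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom else 0.0)
    return scores


if __name__ == "__main__":
    preds = [0, 1, 1, 2, 3, 1]
    targets = [0, 1, 2, 2, 3, 0]
    cm = confusion_matrix(preds, targets, num_classes=len(CLASS_NAMES))
    f1 = per_class_f1(cm)
    for name, score in zip(CLASS_NAMES, f1):
        print(f"{name:>10}: F1 = {score:.3f}")
    print(f"macro F1 = {sum(f1) / len(f1):.3f}")
```

Per-class F1 is useful for a class like Meningioma, which the training code singles out with a larger focal-loss gamma: a weak class can hide behind a high overall accuracy.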

## Orchestrate the pipeline

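The driver task, `tumor_detection_pipeline`, chains the three tasks: it passes the dataset `Dir` from `load_dataset` to `train_model` together with `TRAINING_CONFIG` serialized as JSON, then hands the resulting `Dir` to `create_report`.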
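`train_model` receives its configuration as a JSON string (`config_json`) and rebuilds it with `TrainingConfig(**json.loads(config_json))`. The round trip looks like this in isolation. The dataclass below is a hypothetical, trimmed-down stand-in for the real `TrainingConfig` in `config.py`, which is not shown, and the sketch assumes `to_dict()` returns only JSON-serializable fields.

```python
import json
from dataclasses import asdict, dataclass


@dataclass
class TrainingConfig:
    """Hypothetical, trimmed-down stand-in for the tutorial's TrainingConfig."""
    phase1_epochs: int = 8
    phase2_epochs: int = 25
    phase1_lr: float = 1e-3
    phase2_lr: float = 5e-5
    batch_size: int = 16

    def to_dict(self) -> dict:
        return asdict(self)


# Driver side: the config travels between tasks as a plain string.
config_json = json.dumps(TrainingConfig(phase2_epochs=10).to_dict())

# Task side: rebuild a typed config, mirroring TrainingConfig(**json.loads(config_json)).
restored = TrainingConfig(**json.loads(config_json))
print(restored)
```

In this sketch an unknown key in the JSON raises `TypeError` when the dataclass is constructed, so a mismatched config fails at the start of the task rather than partway through training.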

```

# {{docs-fragment pipeline}}
@pipeline_env.task
async def tumor_detection_pipeline() -> None:
    """Orchestrate dataset loading, GPU training, and report generation."""
    dataset_dir = await load_dataset()
    results_dir = await train_model(
        dataset_dir=dataset_dir,
        config_json=json.dumps(TRAINING_CONFIG.to_dict()),
    )
    await create_report(results_dir=results_dir)
# {{/docs-fragment pipeline}}

if __name__ == "__main__":
    import pathlib
    flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
    run = flyte.with_runcontext().run(tumor_detection_pipeline)
    print(f"\n✓ Pipeline submitted!")
    print(f"Run URL: {run.url}")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/tumor_detection/run.py*
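
The driver passes the training configuration to `train_model` as a JSON string (`json.dumps(TRAINING_CONFIG.to_dict())`), and the task rebuilds it with `TrainingConfig(**json.loads(config_json))`. Because `config.py` is not shown on this page, the sketch below uses a hypothetical dataclass stand-in for `TrainingConfig` whose `to_dict` wraps `dataclasses.asdict`. It only illustrates why the round trip works when every field is a JSON-serializable primitive; the real class may be defined differently.

```python
import json
from dataclasses import asdict, dataclass


# Hypothetical stand-in for config.TrainingConfig; the real class may differ.
@dataclass
class TrainingConfig:
    phase1_epochs: int = 8
    phase2_epochs: int = 25
    phase1_lr: float = 1e-3
    phase2_lr: float = 5e-5
    batch_size: int = 16
    image_size: int = 380
    focal_gamma: float = 3.0

    def to_dict(self) -> dict:
        # asdict converts the dataclass into a plain dict of its fields
        return asdict(self)


def encode_config(config: TrainingConfig) -> str:
    """Serialize the config the way the driver does before calling train_model."""
    return json.dumps(config.to_dict())


def decode_config(config_json: str) -> TrainingConfig:
    """Rebuild the config the way train_model does inside the task."""
    return TrainingConfig(**json.loads(config_json))


original = TrainingConfig(phase2_epochs=10, batch_size=32)
payload = encode_config(original)
restored = decode_config(payload)
print(payload)
print(restored == original)  # dataclass equality compares field by field
```

Encoding the config as a string keeps the task input a plain `str`; the trade-off is that every field must survive `json.dumps`/`json.loads` unchanged.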

## Run the pipeline

This example needs no secrets, since the dataset is public. The pipeline imports sibling modules (`config`, `dataset`, `training`, `utils`) and uses `with_source_folder`, so run it from inside the example directory to make sure those local files are picked up.

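From the root of a clone of the [unionai-examples](https://github.com/unionai/unionai-examples) repository, change into the [example directory](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/tumor_detection) and run the script: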

```
cd v2/tutorials/tumor_detection
python run.py
```

Or submit it with the Flyte CLI from the same directory:

```
flyte run run.py tumor_detection_pipeline
```

When the run completes, open the `create_report` task in the UI to view the training curves, confusion matrix, and per-class F1 scores.
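
The confusion matrix and per-class F1 chart are drawn by `create_confusion_matrix_plot` and `create_per_class_f1_plot` in `utils.py`, which this page does not show. As a rough sketch of the underlying arithmetic only, assuming `preds` and `targets` are integer class indices like the lists `train_model` writes to `predictions.json`, per-class F1 can be derived from a confusion matrix in plain Python (the four-class data below is a made-up example):

```python
def confusion_matrix(preds: list[int], targets: list[int], num_classes: int) -> list[list[int]]:
    """Rows are true classes, columns are predicted classes."""
    cm = [[0] * num_classes for _ in range(num_classes)]
    for p, t in zip(preds, targets):
        cm[t][p] += 1
    return cm


def per_class_f1(cm: list[list[int]]) -> list[float]:
    """F1 = 2 * precision * recall / (precision + recall) for each class."""
    n = len(cm)
    scores = []
    for k in range(n):
        tp = cm[k][k]
        fp = sum(cm[i][k] for i in range(n)) - tp  # predicted k, actually another class
        fn = sum(cm[k]) - tp                        # actually k, predicted another class
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        denom = precision + recall
        scores.append(2 * precision * recall / denom if denom else 0.0)
    return scores


# Hypothetical four-class example
targets = [0, 0, 1, 1, 2, 2, 3, 3]
preds = [0, 1, 1, 1, 2, 0, 3, 3]
cm = confusion_matrix(preds, targets, num_classes=4)
print(cm)
print([round(f, 3) for f in per_class_f1(cm)])
```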

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/biotech-healthcare/genomic-gene-comparison ===

# Cross-species gene comparison

> **📝 Note**
>
> The code for this tutorial is available in the [unionai-examples repository](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/genomic_gene_comparison).

This tutorial builds a bioinformatics pipeline that compares homologous genes across species. The workflow loads curated gene sequences (insulin, hemoglobin, or p53 by default), scores each sequence with the [Carbon](https://huggingface.co/HuggingFaceBio/Carbon-3B) genomic language model, aligns DNA and translated protein sequences, folds proteins with [ESMFold](https://github.com/facebookresearch/esm), and renders interactive HTML reports with identity heatmaps, phylogenetic trees, and 3D structure viewers.
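
Both the identity heatmap and the phylogenetic tree start from a symmetric species-by-species matrix of pairwise identities. The sketch below shows one way to assemble such a matrix. To stay self-contained it includes a compact Needleman-Wunsch identity function using the same default scores as the tutorial's `_sequence_identity` (match 2, mismatch -1, gap -2); the names `nw_identity` and `identity_matrix` are hypothetical and not part of the tutorial code.

```python
from itertools import combinations


def nw_identity(a: str, b: str, match: int = 2, mismatch: int = -1, gap: int = -2) -> float:
    """Matching columns / alignment length for a Needleman-Wunsch global alignment."""
    n, m = len(a), len(b)
    if n == 0 or m == 0:
        return 0.0
    # dp[i][j] is the best score for aligning a[:i] with b[:j]
    dp = [[0] * (m + 1) for _ in range(n + 1)]
    for i in range(1, n + 1):
        dp[i][0] = i * gap
    for j in range(1, m + 1):
        dp[0][j] = j * gap
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            s = match if a[i - 1] == b[j - 1] else mismatch
            dp[i][j] = max(dp[i - 1][j - 1] + s, dp[i - 1][j] + gap, dp[i][j - 1] + gap)
    # Trace back from the bottom-right corner, preferring diagonal moves
    i, j, matches, length = n, m, 0, 0
    while i > 0 or j > 0:
        if i > 0 and j > 0:
            s = match if a[i - 1] == b[j - 1] else mismatch
            if dp[i][j] == dp[i - 1][j - 1] + s:
                if a[i - 1] == b[j - 1]:
                    matches += 1
                i, j, length = i - 1, j - 1, length + 1
                continue
        if i > 0 and dp[i][j] == dp[i - 1][j] + gap:
            i -= 1  # gap in b
        else:
            j -= 1  # gap in a
        length += 1
    return matches / length


def identity_matrix(seqs: dict[str, str]) -> tuple[list[str], list[list[float]]]:
    """Symmetric identity matrix with 1.0 on the diagonal; each pair is aligned once."""
    names = list(seqs)
    k = len(names)
    matrix = [[1.0] * k for _ in range(k)]
    for i, j in combinations(range(k), 2):
        matrix[i][j] = matrix[j][i] = nw_identity(seqs[names[i]], seqs[names[j]])
    return names, matrix


toy = {"A": "ATGGCCCTG", "B": "ATGGCCCTA", "C": "ATGACTCTG"}
names, matrix = identity_matrix(toy)
for name, row in zip(names, matrix):
    print(name, [round(v, 3) for v in row])
```

Aligning each unordered pair once and mirroring the result halves the number of alignments, which matters because each Needleman-Wunsch call is quadratic in sequence length.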

Flyte makes the multi-stage GPU/CPU pipeline reliable:

- **Separate CPU and GPU `TaskEnvironment`s** so alignment runs on modest CPU nodes while Carbon scoring and ESMFold run on GPUs.
- **`report=True`** on every stage for live HTML progress and final summaries in the Flyte UI.
- **Cached data loading** and orchestrated fan-out across pipeline stages.
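
The tasks in these tutorials are `async` functions that a parent task awaits, as `tumor_detection_pipeline` does, so fanning out over species can be written as ordinary `asyncio` code. The sketch below shows only that pattern in plain Python: `score_species` is a hypothetical stand-in for a GPU task (it just computes GC content) so the example runs without a cluster, and it does not reproduce the tutorial's actual task signatures.

```python
import asyncio


async def score_species(species: str, dna: str) -> tuple[str, float]:
    """Hypothetical stand-in for a per-species task; returns GC content."""
    await asyncio.sleep(0)  # yield to the event loop, as a real remote call would
    gc = sum(base in "GC" for base in dna.upper()) / len(dna) if dna else 0.0
    return species, gc


async def fan_out(sequences: dict[str, str]) -> dict[str, float]:
    # Start one call per species, then wait for all of them together
    results = await asyncio.gather(
        *(score_species(name, dna) for name, dna in sequences.items())
    )
    return dict(results)


toy = {"Human": "ATGGCCCTG", "Mouse": "ATGGCCCTA"}
print(asyncio.run(fan_out(toy)))
```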

## Define the task environments

GPU tasks handle Carbon log-likelihood scoring and ESMFold structure prediction; CPU tasks load gene sets, run Needleman-Wunsch alignments, and generate the final summary.
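
The Carbon scoring code is not part of this excerpt. For orientation only: a causal language model's score for a sequence is typically built from the log-probability it assigns to each observed token, and those values are often summarized as a mean log-likelihood and a perplexity, `exp(-mean)`. The hypothetical helper below sketches that summary, assuming natural-log token probabilities are already available; it is not the tutorial's scoring code, and the pipeline may report different statistics.

```python
import math


def summarize_log_likelihood(token_logprobs: list[float]) -> dict[str, float]:
    """Summarize natural-log token probabilities from a causal language model.

    Hypothetical helper for illustration; not taken from the tutorial.
    """
    if not token_logprobs:
        raise ValueError("need at least one token log-probability")
    total = sum(token_logprobs)
    mean = total / len(token_logprobs)
    return {
        "total_log_likelihood": total,
        "mean_log_likelihood": mean,
        # Perplexity is the exponentiated average negative log-likelihood
        "perplexity": math.exp(-mean),
    }


# A model that gives probability 0.25 to every observed token has perplexity 4,
# the same as a uniform guess among four options.
print(summarize_log_likelihood([math.log(0.25)] * 12))
```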

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.4.0",
#    "torch>=2.9.0",
#    "transformers>=4.49.0",
#    "accelerate>=0.34.0",
#    "numpy",
# ]
# main = "pipeline"
# params = ""
# ///
import json
import logging
import math
import os
import tempfile

import flyte
import flyte.io
import flyte.report

# {{docs-fragment env}}
main_img = flyte.Image.from_uv_script(__file__, name="genomic-gene-comparison", pre=True)

gpu_env = flyte.TaskEnvironment(
    name="genomic-gene-comparison-gpu",
    image=main_img,
    resources=flyte.Resources(cpu=4, memory="32Gi", gpu=1),
)

cpu_env = flyte.TaskEnvironment(
    name="genomic-gene-comparison-cpu",
    image=main_img,
    resources=flyte.Resources(cpu=2, memory="8Gi"),
    depends_on=[gpu_env],
)
# {{/docs-fragment env}}

logging.basicConfig(level=logging.WARNING, format="%(message)s", force=True)
log = logging.getLogger(__name__)
log.setLevel(logging.INFO)

# ------------------------------------------------------------------
# Homologous gene sets - same gene across species
# ------------------------------------------------------------------
# Full-length coding sequences from NCBI RefSeq (stop codon excluded).

GENE_SETS = {
    "insulin": {
        "gene_name": "Insulin",
        "description": "Insulin regulates blood sugar in all vertebrates. Highly conserved across 500M+ years of evolution - even fish insulin can lower blood sugar in humans. Comparing across species reveals which regions are functionally essential (conserved) vs free to drift.",
        "sequences": {
            "Human": {
                "dna": "ATGGCCCTGTGGATGCGCCTCCTGCCCCTGCTGGCGCTGCTGGCCCTCTGGGGACCTGACCCAGCCGCAGCCTTTGTGAACCAACACCTGTGCGGCTCACACCTGGTGGAAGCTCTCTACCTAGTGTGCGGGGAACGAGGCTTCTTCTACACACCCAAGACCCGCCGGGAGGCAGAGGACCTGCAGGTGGGGCAGGTGGAGCTGGGCGGGGGCCCTGGTGCAGGCAGCCTGCAGCCCTTGGCCCTGGAGGGGTCCCTGCAGAAGCGTGGCATTGTGGAACAATGCTGTACCAGCATCTGCTCCCTCTACCAGCTGGAGAACTACTGCAAC",
                "common_name": "Homo sapiens",
            },
            "Mouse": {
                "dna": "ATGGCCCTGTGGATGCGCTTCCTGCCCCTGCTGGCCCTGCTCTTCCTCTGGGAGTCCCACCCCACCCAGGCTTTTGTCAAGCAGCACCTTTGTGGTTCCCACCTGGTGGAGGCTCTCTACCTGGTGTGTGGGGAGCGTGGCTTCTTCTACACACCCATGTCCCGCCGTGAAGTGGAGGACCCACAAGTGGCACAACTGGAGCTGGGTGGAGGCCCGGGAGCAGGTGACCTTCAGACCTTGGCACTGGAGGTGGCCCAGCAGAAGCGTGGCATTGTAGATCAGTGCTGCACCAGCATCTGCTCCCTCTACCAGCTGGAGAACTACTGCAAC",
                "common_name": "Mus musculus",
            },
            "Chicken": {
                "dna": "ATGGCTCTCTGGATCCGATCACTGCCTCTTCTGGCTCTCCTTGTCTTTTCTGGCCCTGGAACCAGCTATGCAGCTGCCAACCAGCACCTCTGTGGCTCCCACTTGGTGGAGGCTCTCTACCTGGTGTGTGGAGAGCGTGGCTTCTTCTACTCCCCCAAAGCCCGACGGGATGTCGAGCAGCCCCTAGTGAGCAGTCCCTTGCGTGGCGAGGCAGGAGTGCTGCCTTTCCAGCAGGAGGAATACGAGAAAGTCAAGCGAGGGATTGTTGAGCAATGCTGCCATAACACGTGTTCCCTCTACCAACTGGAGAACTACTGCAAC",
                "common_name": "Gallus gallus",
            },
            "Zebrafish": {
                "dna": "ATGGCAGTGTGGCTTCAGGCTGGTGCTCTGTTGGTCCTGTTGGTCGTGTCCAGTGTAAGCACTAACCCAGGCACACCGCAGCACCTGTGTGGATCTCATCTGGTCGATGCCCTTTATCTGGTCTGTGGCCCAACAGGCTTCTTCTACAACCCCAAGAGAGACGTTGAGCCCCTTCTGGGTTTCCTTCCTCCTAAATCTGCCCAGGAAACTGAGGTGGCTGACTTTGCATTTAAAGATCATGCCGAGCTGATAAGGAAGAGAGGCATTGTAGAGCAGTGCTGCCACAAACCCTGCAGCATCTTTGAGCTGCAGAACTACTGTAAC",
                "common_name": "Danio rerio",
            },
            "Frog": {
                "dna": "ATGGCTCTATGGATGCAGTGTCTGCCCCTGGTTCTTGTCCTCTTTTTCTCTACACCCAACACCGAAGCTCTAGTTAACCAGCACTTGTGTGGGTCTCACCTGGTAGAAGCCCTGTACTTAGTATGTGGGGATCGAGGCTTCTTCTACTACCCTAAGGTCAAACGGGACATGGAACAAGCACTTGTCAGTGGACCCCAGGATAATGAGTTGGATGGAATGCAGCTCCAGCCTCAGGAGTATCAGAAAATGAAGAGGGGGATTGTGGAGCAATGTTGCCACAGCACATGTTCTCTCTTCCAGCTGGAGAGTTACTGCAAC",
                "common_name": "Xenopus laevis",
            },
            "Cow": {
                "dna": "ATGGCCCTGTGGACACGCCTGGCGCCCCTGCTGGCCCTGCTGGCGCTCTGGGCCCCCGCCCCGGCCCGCGCCTTCGTCAACCAGCATCTGTGTGGCTCCCACCTGGTGGAGGCGCTGTACCTGGTGTGCGGAGAGCGCGGCTTCTTCTACACGCCCAAGGCCCGCCGGGAGGTGGAGGGCCCCCAGGTGGGGGCGCTGGAGCTGGCCGGAGGCCCGGGCGCGGGCGGCCTGGAGGGGCCCCCGCAGAAGCGTGGCATCGTGGAGCAGTGCTGTGCCAGCGTCTGCTCGCTCTACCAGCTGGAGAACTACTGTAAC",
                "common_name": "Bos taurus",
            },
        },
    },
    "hemoglobin": {
        "gene_name": "Hemoglobin Beta",
        "description": "Beta-globin carries oxygen from lungs to tissues. The most studied gene in molecular evolution - sequence differences power the 'molecular clock' hypothesis. Sickle cell mutation (E6V) in humans shows how a single base change creates devastating disease.",
        "sequences": {
            "Human": {
                "dna": "ATGGTGCATCTGACTCCTGAGGAGAAGTCTGCCGTTACTGCCCTGTGGGGCAAGGTGAACGTGGATGAAGTTGGTGGTGAGGCCCTGGGCAGGCTGCTGGTGGTCTACCCTTGGACCCAGAGGTTCTTTGAGTCCTTTGGGGATCTGTCCACTCCTGATGCTGTTATGGGCAACCCTAAGGTGAAGGCTCATGGCAAGAAAGTGCTCGGTGCCTTTAGTGATGGCCTGGCTCACCTGGACAACCTCAAGGGCACCTTTGCCACACTGAGTGAGCTGCACTGTGACAAGCTGCACGTGGATCCTGAGAACTTCAGGCTCCTGGGCAACGTGCTGGTCTGTGTGCTGGCCCATCACTTTGGCAAAGAATTCACCCCACCAGTGCAGGCTGCCTATCAGAAAGTGGTGGCTGGTGTGGCTAATGCCCTGGCCCACAAGTATCAC",
                "common_name": "Homo sapiens",
            },
            "Mouse": {
                "dna": "ATGGTGCACCTGACTGATGCTGAGAAGGCTGCTGTCTCTGGCCTGTGGGGAAAGGTGAACGCCGATGAAGTTGGTGGTGAGGCCCTGGGCAGGCTGCTGGTTGTCTACCCTTGGACCCAGCGGTACTTTGATAGCTTTGGAGACCTATCCTCTGCCTCTGCTATCATGGGTAATGCCAAAGTGAAGGCCCATGGCAAGAAAGTGATAACTGCCTTTAACGATGGCCTGAATCACTTGGACAGCCTCAAGGGCACCTTTGCCAGCCTCAGTGAGCTCCACTGTGACAAGCTGCATGTGGATCCTGAGAACTTCAGGCTCCTGGGCAATATGATCGTGATTGTGCTGGGCCACCACCTGGGCAAGGATTTCACCCCCGCTGCACAGGCTGCCTTCCAGAAGGTGGTGGCTGGAGTGGCTGCTGCCCTGGCTCACAAGTACCAC",
                "common_name": "Mus musculus",
            },
            "Chicken": {
                "dna": "ATGGTGCACTGGACTGCTGAGGAGAAGCAGCTCATCACCGGCCTCTGGGGCAAGGTCAATGTGGCCGAATGTGGGGCTGAAGCCCTGGCCAGGCTGCTGATCGTCTACCCCTGGACCCAGAGGTTCTTTGCGTCCTTTGGGAACCTCTCCAGCCCCACTGCCATCCTTGGCAACCCCATGGTCCGCGCCCATGGCAAGAAAGTGCTCACCTCCTTTGGGGATGCTGTGAAGAACCTGGACAACATCAAGAACACCTTCTCCCAACTGTCCGAACTGCATTGTGACAAGCTGCATGTGGACCCCGAGAACTTCAGGCTCCTGGGTGACATCCTCATCATTGTCCTGGCCGCCCACTTCAGCAAGGACTTCACTCCTGAATGCCAGGCTGCCTGGCAGAAGCTGGTCCGCGTGGTGGCCCATGCCCTGGCTCGCAAGTACCAC",
                "common_name": "Gallus gallus",
            },
            "Zebrafish": {
                "dna": "ATGGTTGAGTGGACAGATGCCGAGCGCACAGCCATCCTTGGCCTGTGGGGAAAGCTCAATATCGATGAAATCGGACCTCAGGCCCTATCCAGATGTCTGATCGTGTATCCCTGGACTCAGAGATATTTCGCCACATTCGGCAACCTGTCAAGCCCCGCTGCGATCATGGGTAACCCCAAAGTGGCAGCTCATGGGAGGACTGTGATGGGAGGTCTTGAGAGAGCCATCAAGAACATGGACAACGTCAAGAACACCTATGCCGCCCTCAGTGTGATGCACTCTGAGAAACTGCATGTGGATCCCGACAACTTCAGGCTTCTCGCTGATTGCATCACCGTTTGCGCTGCCATGAAGTTCGGCCAAGCTGGTTTCAATGCTGATGTCCAGGAGGCCTGGCAGAAGTTTCTGGCTGTGGTCGTTTCTGCTCTGTGCAGACAGTACCAC",
                "common_name": "Danio rerio",
            },
            "Frog": {
                "dna": "ATGGTTCATTGGACAGCTGAAGAGAAGGCCGCCATCACCTCTGTGTGGCAGGAGGTCAACCAGGAGCAAGATGGCCATGATGCACTCACAAGGCTGCTGGTTGTGTACCCCTGGACCCAGAGATACTTCAGCAGTTTTGGAAATCTCGGTAATGCCACAGCTATTGCTGGAAATGTCAAGGTGCGTGCCCATGGCAAGAAGGTTCTTTCAGCTGTTGGTGATGCCATCGCCCATCTTGACAACGTGAAGGGAACTCTCCATGACCTCAGTGTGGTCCACGCCTTCAAGCTCTATGTGGATCCTGAGAACTTCAAGCGTCTTGGTGAAGTTCTGGTCATTGTCTTGGCTTCCAAACTGGGATCAGCCTTTACTCCTCAAGTCCAGGGAGCCTGGGAGAAATTTGTTGCTGTTCTGGTTGATGCCCTCAGCCAAGGATACAAC",
                "common_name": "Xenopus laevis",
            },
            "Cow": {
                "dna": "ATGCTGACTGCTGAGGAGAAGGCTGCCGTCACCGCCTTTTGGGGCAAGGTGAAAGTGGATGAAGTTGGTGGTGAGGCCCTGGGCAGGCTGCTGGTTGTCTACCCCTGGACTCAGAGGTTCTTTGAGTCCTTTGGGGACTTGTCCACTGCTGATGCTGTTATGAACAACCCTAAGGTGAAGGCCCATGGCAAGAAGGTGCTAGATTCCTTTAGTAATGGCATGAAGCATCTCGATGACCTCAAGGGCACCTTTGCTGCGCTGAGTGAGCTGCACTGTGATAAGCTGCATGTGGATCCTGAGAACTTCAAGCTCCTGGGCAACGTGCTAGTGGTTGTGCTGGCTCGCAATTTTGGCAAGGAATTCACCCCGGTGCTGCAGGCTGACTTTCAGAAGGTGGTGGCTGGTGTGGCCAATGCCCTGGCCCACAGATATCAT",
                "common_name": "Bos taurus",
            },
        },
    },
    "p53": {
        "gene_name": "p53 (TP53)",
        "description": "The 'guardian of the genome' - p53 detects DNA damage and triggers repair or cell death. Mutated in >50% of human cancers. Elephants have 20 copies of p53 (humans have 1), which may explain their extremely low cancer rates despite their size (Peto's paradox).",
        "sequences": {
            "Human": {
                "dna": "ATGGAGGAGCCGCAGTCAGATCCTAGCGTCGAGCCCCCTCTGAGTCAGGAAACATTTTCAGACCTATGGAAACTACTTCCTGAAAACAACGTTCTGTCCCCCTTGCCGTCCCAAGCAATGGATGATTTGATGCTGTCCCCGGACGATATTGAACAATGGTTCACTGAAGACCCAGGTCCAGATGAAGCTCCCAGAATGCCAGAGGCTGCTCCCCCCGTGGCCCCTGCACCAGCAGCTCCTACACCGGCGGCCCCTGCACCAGCCCCCTCCTGGCCCCTGTCATCTTCTGTCCCTTCCCAGAAAACCTACCAGGGCAGCTACGGTTTCCGTCTGGGCTTCTTGCATTCTGGGACAGCCAAGTCTGTGACTTGCACGTACTCCCCTGCCCTCAACAAGATGTTTTGCCAACTGGCCAAGACCTGCCCTGTGCAGCTGTGGGTTGATTCCACACCCCCGCCCGGCACCCGCGTCCGCGCCATGGCCATCTACAAGCAGTCACAGCACATGACGGAGGTTGTGAGGCGCTGCCCCCACCATGAGCGCTGCTCAGATAGCGATGGTCTGGCCCCTCCTCAGCATCTTATCCGAGTGGAAGGAAATTTGCGTGTGGAGTATTTGGATGACAGAAACACTTTTCGACATAGTGTGGTGGTGCCCTATGAGCCGCCTGAGGTTGGCTCTGACTGTACCACCATCCACTACAACTACATGTGTAACAGTTCCTGCATGGGCGGCATGAACCGGAGGCCCATCCTCACCATCATCACACTGGAAGACTCCAGTGGTAATCTACTGGGACGGAACAGCTTTGAGGTGCGTGTTTGTGCCTGTCCTGGGAGAGACCGGCGCACAGAGGAAGAGAATCTCCGCAAGAAAGGGGAGCCTCACCACGAGCTGCCCCCAGGGAGCACTAAGCGAGCACTGCCCAACAACACCAGCTCCTCTCCCCAGCCAAAGAAGAAACCACTGGATGGAGAATATTTCACCCTTCAGATCCGTGGGCGTGAGCGCTTCGAGATGTTCCGAGAGCTGAATGAGGCCTTGGAACTCAAGGATGCCCAGGCTGGGAAGGAGCCAGGGGGGAGCAGGGCTCACTCCAGCCACCTGAAGTCCAAAAAGGGTCAGTCTACCTCCCGCCATAAAAAACTCATGTTCAAGACAGAAGGGCCTGACTCAGAC",
                "common_name": "Homo sapiens",
            },
            "Mouse": {
                "dna": "ATGACTGCCATGGAGGAGTCACAGTCGGATATCAGCCTCGAGCTCCCTCTGAGCCAGGAGACATTTTCAGGCTTATGGAAACTACTTCCTCCAGAAGATATCCTGCCATCACCTCACTGCATGGACGATCTGTTGCTGCCCCAGGATGTTGAGGAGTTTTTTGAAGGCCCAAGTGAAGCCCTCCGAGTGTCAGGAGCTCCTGCAGCACAGGACCCTGTCACCGAGACCCCTGGGCCAGTGGCCCCTGCCCCAGCCACTCCATGGCCCCTGTCATCTTTTGTCCCTTCTCAAAAAACTTACCAGGGCAACTATGGCTTCCACCTGGGCTTCCTGCAGTCTGGGACAGCCAAGTCTGTTATGTGCACGTACTCTCCTCCCCTCAATAAGCTATTCTGCCAGCTGGCGAAGACGTGCCCTGTGCAGTTGTGGGTCAGCGCCACACCTCCAGCTGGGAGCCGTGTCCGCGCCATGGCCATCTACAAGAAGTCACAGCACATGACGGAGGTCGTGAGACGCTGCCCCCACCATGAGCGCTGCTCCGATGGTGATGGCCTGGCTCCTCCCCAGCATCTTATCCGGGTGGAAGGAAATTTGTATCCCGAGTATCTGGAAGACAGGCAGACTTTTCGCCACAGCGTGGTGGTACCTTATGAGCCACCCGAGGCCGGCTCTGAGTATACCACCATCCACTACAAGTACATGTGTAATAGCTCCTGCATGGGGGGCATGAACCGCCGACCTATCCTTACCATCATCACACTGGAAGACTCCAGTGGGAACCTTCTGGGACGGGACAGCTTTGAGGTTCGTGTTTGTGCCTGCCCTGGGAGAGACCGCCGTACAGAAGAAGAAAATTTCCGCAAAAAGGAAGTCCTTTGCCCTGAACTGCCCCCAGGGAGCGCAAAGAGAGCGCTGCCCACCTGCACAAGCGCCTCTCCCCCGCAAAAGAAAAAACCACTTGATGGAGAGTATTTCACCCTCAAGATCCGCGGGCGTAAACGCTTCGAGATGTTCCGGGAGCTGAATGAGGCCTTAGAGTTAAAGGATGCCCATGCTACAGAGGAGTCTGGAGACAGCAGGGCTCACTCCAGCTACCTGAAGACCAAGAAGGGCCAGTCTACTTCCCGCCATAAAAAAACAATGGTCAAGAAAGTGGGGCCTGACTCAGAC",
                "common_name": "Mus musculus",
            },
            "Chicken": {
                "dna": "ATGGCGGAGGAGATGGAACCATTGCTGGAACCCACTGAGGTCTTCATGGACCTCTGGAGCATGCTCCCCTATAGCATGCAACAGCTGCCCCTCCCTGAGGATCACAGCAACTGGCAGGAGCTGAGCCCCCTGGAACCCAGCGACCCCCCCCCACCACCGCCACCACCACCTCTGCCATTGGCCGCCGCCGCCCCCCCCCCATTAAACCCCCCCACCCCCCCCCGCGCTGCCCCCTCCCCGGTGGTCCCATCCACGGAGGATTATGGGGGGGACTTCGACTTCCGGGTGGGGTTCGTGGAGGCGGGCACAGCCAAATCGGTCACCTGCACTTACTCCCCGGTGCTGAATAAGGTCTATTGCCGCCTGGCCAAGCCGTGCCCGGTGCAGGTGAGGGTGGGGGTGGCGCCCCCCCCCGGTTCCTCCCTCCGCGCCGTGGCCGTCTATAAGAAATCAGAGCACGTGGCCGAAGTGGTGCGGCGCTGCCCCCACCACGAGCGCTGCGGGGGGGGCACCGACGGCCTGGCCCCCGCACAGCACCTCATCCGGGTGGAGGGGAACCCCCAGGCGCGTTACCACGACGACGAGACCACCAAACGGCACAGCGTCGTCGTCCCCTATGAGCCCCCCGAGGTGGGCTCTGACTGTACCACGGTGCTGTACAACTTCATGTGCAACAGTTCCTGCATGGGGGGGATGAACCGCCGCCCCATCCTCACCATCCTTACACTGGAGGGGCCGGGGGGGCAGCTGTTGGGGCGGCGCTGCTTCGAGGTGCGCGTGTGCGCATGTCCGGGGAGGGACCGCAAGATCGAGGAGGAGAACTTCCGCAAGAGGGGCGGGGCCGGGGGCGTGGCTAAGCGAGCCATGTCGCCCCCAACCGAAGCCCCCGAGCCCCCCAAGAAGCGCGTGCTGAACCCCGACAATGAGATATTCTACCTGCAGGTGCGCGGGCGCCGCCGCTATGAGATGCTGAAGGAGATCAATGAGGCGCTGCAGCTCGCCGAGGGGGGGTCCGCACCGCGGCCTTCCAAAGGCCGCCGTGTGAAGGTGGAGGGACCCCAACCCAGCTGCGGGAAGAAACTGCTGCAAAAAGGCTCGGAC",
                "common_name": "Gallus gallus",
            },
            "Zebrafish": {
                "dna": "ATGGCGCAAAACGACAGCCAAGAGTTCGCGGAGCTCTGGGAGAAGAATTTGATTATTCAGCCCCCAGGTGGTGGCTCTTGCTGGGACATCATTAATGATGAGGAGTACTTGCCGGGATCGTTTGACCCCAATTTTTTTGAAAATGTGCTTGAAGAACAGCCTCAGCCATCCACTCTCCCACCAACATCCACTGTTCCGGAGACAAGCGACTATCCCGGCGATCATGGATTTAGGCTCAGGTTCCCGCAGTCTGGCACAGCAAAATCTGTAACTTGCACTTATTCACCGGACCTGAATAAACTCTTCTGTCAGCTGGCAAAAACTTGCCCCGTTCAAATGGTGGTGGACGTTGCCCCTCCACAGGGCTCCGTGGTTCGAGCCACTGCCATCTATAAGAAGTCCGAGCATGTGGCTGAAGTGGTCCGCAGATGCCCCCATCATGAGCGAACCCCGGATGGAGATAACTTGGCGCCTGCTGGTCATTTGATAAGAGTGGAGGGCAATCAGCGAGCAAATTACAGGGAAGATAACATCACTTTAAGGCATAGTGTTTTTGTCCCATATGAAGCACCACAGCTTGGTGCTGAATGGACAACTGTGCTACTAAACTACATGTGCAATAGCAGCTGCATGGGGGGGATGAACCGCAGGCCCATCCTCACAATCATCACTCTGGAGACTCAGGAAGGTCAGTTGCTGGGCCGGAGGTCTTTTGAGGTGCGTGTGTGTGCATGTCCAGGCAGAGACAGGAAAACTGAGGAGAGCAACTTCAAGAAAGACCAAGAGACCAAAACCATGGCCAAAACCACCACTGGGACCAAACGTAGTTTGGTGAAAGAATCTTCTTCAGCTACATTACGACCTGAGGGGAGCAAAAAGGCCAAGGGCTCCAGCAGCGATGAGGAGATCTTTACCCTGCAGGTGAGGGGCAGGGAGCGTTATGAAATTTTAAAGAAATTGAACGACAGTCTGGAGTTAAGTGATGTGGTGCCTGCCTCAGATGCTGAAAAGTATCGTCAGAAATTCATGACAAAAAACAAAAAAGAGAATCGTGAATCATCTGAGCCCAAACAGGGAAAGAAGCTGATGGTGAAGGACGAAGGAAGAAGCGACTCTGAT",
                "common_name": "Danio rerio",
            },
            "Elephant": {
                "dna": "ATGGAGGAGCCCCAGTCAGATCTCAGCACCGAGCTCCCTCTGAGTCAAGAGACGTTTTCATACTTATGGGAACTCCTTCCTGAGAATCCGGTTCTGTCCCCCACACTACCCCCGGCAGTGGAGGTCATGGACGATCTGCTACTCTCAGAAGACACTGCAAACTGGCTAGAAAGCCAAGTTGAGGCTCAGGGAATGTCCACAACCCCTGCACCAGCCACCCCTACACCGGTGGCCCCCGCACCAGCCACCTCCTGGACCCTGTCATCTTCCGTCCCTTCCCAAAAGACCTACCCTGGCACCTATGGTTTCCGTCTGGGCTTCCTACATTCTGGGACAGCCAAGTCCGTCACCTGCACGTACTCCCCTGACCTTAACAAGCTGTTTTGCCAGCTGGCAAAAACCTGCCCAGTGCAGCTGTGGGTCGCCTCACCACCCCCGCCCGGCACCCGTGTTCGCACCATGGCCATCTACAAGAAGTCAGAGCATATGACGGAGGTCGTCAAGCGCTGCCCCCACCATGAGCGCTGCTCTGACTCTAGCGATGGCCTGGCCCCTCCTCAGCACCTCATCCGGGTGGAAGGAAACCTGCGTGCTGAGTATCTGGAGGACAGCATCACTCTCCGACACAGTGTGGTGGTGCCCTACGAGCCGCCCGAGGTTGGGTCTGACTGTACCACCATCCACTTCAACTTCATGTGTAACAGCTCCTGCATGGGGGGCATGAACCGGCGGCCCATCCTCACCATCATCACACTGGAAGACTCCAGTGGTAATCTGCTGGGACGTAACAGCTTTGAGGTGCGCATTTGTGCCTGTCCTGGAAGAGACAGACGTACAGAAGAAGAAAATTTCCACAAGAAGGGAGAGCCTTGCCCAGAGCCGCCACCCCCTGGGAGGAGCACTAAGCGAGCACTGCCCACCAACACCAGCTCCTCTACCCAGCCAAAGAAGAAGCCACTGGATGAAGAATATTTCACCCTTCAGATCCGTGGGCGTGAACGCTTCAAGATGTTCCTAGAGCTAAATGAGGCCTTGGAGCTGAAGGATGCCCAGGCTGGGAAGGAGCCAGAGGGGAGCCGGGCTCACTCCAGCCCTTCGAAGTCTAAGAAGGGACAGTCTACCTCCCGCCATAAAAAACCAATGTTCAAGAGAGAGGGACCTGACTCAGAC",
                "common_name": "Loxodonta africana",
            },
            "Dog": {
                "dna": "ATGGAGGAGTCGCAGTCAGAGCTCAATATCGACCCCCCTCTGAGCCAGGAGACATTTTCAGAATTGTGGAACCTGCTTCCTGAAAACAATGTTCTGTCTTCGGAGCTGTGCCCAGCAGTGGATGAGCTGCTGCTCCCAGAGAGCGTCGTGAACTGGCTAGACGAAGACTCAGATGATGCTCCCAGGATGCCAGCCACTTCTGCCCCCACAGCCCCTGGACCGGCCCCCTCGTGGCCCCTATCATCCTCTGTCCCTTCCCCGAAGACCTACCCTGGCACCTATGGGTTCCGTTTGGGGTTCCTGCATTCCGGGACAGCCAAGTCTGTTACTTGGACGTACTCCCCTCTCCTCAACAAGTTGTTTTGCCAGCTGGCGAAGACCTGCCCCGTGCAGCTGTGGGTCAGCTCCCCACCCCCACCCAATACCTGCGTCCGCGCTATGGCCATCTATAAGAAGTCGGAGTTCGTGACCGAGGTTGTGCGGCGCTGCCCCCACCATGAACGCTGCTCTGACAGTAGTGACGGTCTTGCCCCTCCTCAGCATCTCATCCGAGTGGAAGGAAATTTGCGGGCCAAGTACCTGGACGACAGAAACACTTTTCGACACAGTGTGGTGGTGCCTTATGAGCCACCCGAGGTTGGCTCTGACTATACCACCATCCACTACAACTACATGTGTAACAGTTCCTGCATGGGAGGCATGAACCGGCGGCCCATCCTCACTATCATCACCCTGGAAGACTCCAGTGGAAACGTGCTGGGACGCAACAGCTTTGAGGTACGCGTTTGTGCCTGTCCCGGGAGAGACCGCCGGACTGAGGAGGAGAATTTCCACAAGAAGGGGGAGCCTTGTCCTGAGCCACCCCCCGGGAGTACCAAGCGAGCACTGCCTCCCAGCACCAGCTCCTCTCCCCCGCAAAAGAAGAAGCCACTAGATGGAGAATATTTCACCCTTCAGATCCGTGGGCGTGAACGCTATGAGATGTTCAGGAATCTGAATGAAGCCTTGGAGCTGAAGGATGCCCAGAGTGGAAAGGAGCCAGGGGGAAGCAGGGCTCACTCCAGCCACCTGAAGGCAAAGAAGGGGCAATCTACCTCTCGCCATAAAAAACTGATGTTCAAGAGAGAAGGGCTTGACTCAGAC",
                "common_name": "Canis lupus familiaris",
            },
        },
    },
}

# Standard genetic code
CODON_TABLE = {
    "TTT": "F", "TTC": "F", "TTA": "L", "TTG": "L", "CTT": "L", "CTC": "L",
    "CTA": "L", "CTG": "L", "ATT": "I", "ATC": "I", "ATA": "I", "ATG": "M",
    "GTT": "V", "GTC": "V", "GTA": "V", "GTG": "V", "TCT": "S", "TCC": "S",
    "TCA": "S", "TCG": "S", "CCT": "P", "CCC": "P", "CCA": "P", "CCG": "P",
    "ACT": "T", "ACC": "T", "ACA": "T", "ACG": "T", "GCT": "A", "GCC": "A",
    "GCA": "A", "GCG": "A", "TAT": "Y", "TAC": "Y", "TAA": "*", "TAG": "*",
    "CAT": "H", "CAC": "H", "CAA": "Q", "CAG": "Q", "AAT": "N", "AAC": "N",
    "AAA": "K", "AAG": "K", "GAT": "D", "GAC": "D", "GAA": "E", "GAG": "E",
    "TGT": "C", "TGC": "C", "TGA": "*", "TGG": "W", "CGT": "R", "CGC": "R",
    "CGA": "R", "CGG": "R", "AGT": "S", "AGC": "S", "AGA": "R", "AGG": "R",
    "GGT": "G", "GGC": "G", "GGA": "G", "GGG": "G",
}

BASE_COLORS = {"A": "#2ecc71", "T": "#e74c3c", "G": "#f39c12", "C": "#3498db"}

def _translate(dna: str) -> str:
    """Translate DNA to protein in reading frame 0."""
    dna = dna.upper()
    protein = []
    for i in range(0, len(dna) - 2, 3):
        codon = dna[i:i + 3]
        aa = CODON_TABLE.get(codon, "X")
        if aa == "*":
            break
        protein.append(aa)
    return "".join(protein)

def _gc_content(seq: str) -> float:
    """Fraction of G and C bases in a sequence (0.0 for an empty sequence)."""
    if not seq:
        return 0.0
    return sum(1 for b in seq.upper() if b in "GC") / len(seq)

def _sequence_identity(seq1: str, seq2: str, match: int = 2, mismatch: int = -1, gap: int = -2) -> float:
    """Percent identity via Needleman-Wunsch global alignment."""
    if not seq1 or not seq2:
        return 0.0
    n, m = len(seq1), len(seq2)

    # Build score matrix
    dp = [[0] * (m + 1) for _ in range(n + 1)]
    for i in range(1, n + 1):
        dp[i][0] = dp[i - 1][0] + gap
    for j in range(1, m + 1):
        dp[0][j] = dp[0][j - 1] + gap
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            s = match if seq1[i - 1] == seq2[j - 1] else mismatch
            dp[i][j] = max(dp[i - 1][j - 1] + s, dp[i - 1][j] + gap, dp[i][j - 1] + gap)

    # Traceback to count matches and alignment length
    i, j = n, m
    matches = 0
    aligned = 0
    while i > 0 or j > 0:
        if i > 0 and j > 0:
            s = match if seq1[i - 1] == seq2[j - 1] else mismatch
            if dp[i][j] == dp[i - 1][j - 1] + s:
                if seq1[i - 1] == seq2[j - 1]:
                    matches += 1
                aligned += 1
                i -= 1
                j -= 1
                continue
        if i > 0 and dp[i][j] == dp[i - 1][j] + gap:
            aligned += 1
            i -= 1
        else:
            aligned += 1
            j -= 1

    return matches / aligned if aligned else 0.0

# ------------------------------------------------------------------
# Report styling
# ------------------------------------------------------------------

REPORT_CSS = """
<style>
  .report { font-family: system-ui, -apple-system, sans-serif; max-width: 960px; margin: 0 auto; color: #1a1a2e; }
  .report h2 { color: #1e3a5f; border-bottom: 2px solid #2563eb; padding-bottom: 8px; margin-top: 24px; }
  .report h3 { color: #1e40af; margin-top: 20px; }
  .report .card { background: #eff6ff; border: 1px solid #bfdbfe; border-radius: 8px; padding: 16px; margin: 12px 0; }
  .report .stat-grid { display: grid; grid-template-columns: repeat(auto-fit, minmax(140px, 1fr)); gap: 12px; margin: 12px 0; }
  .report .stat { background: #fff; border: 1px solid #dbeafe; border-radius: 6px; padding: 12px; text-align: center; }
  .report .stat .value { font-size: 1.5em; font-weight: 700; color: #1e3a5f; }
  .report .stat .label { font-size: 0.85em; color: #6c757d; margin-top: 4px; }
  .report table { border-collapse: collapse; width: 100%; margin: 12px 0; }
  .report th { background: #1e3a5f; color: #fff; padding: 10px 14px; text-align: left; font-weight: 600; }
  .report td { padding: 8px 14px; border-bottom: 1px solid #dbeafe; }
  .report tr:nth-child(even) { background: #eff6ff; }
  .report .badge { display: inline-block; padding: 2px 8px; border-radius: 12px; font-size: 0.8em; font-weight: 600; }
  .report .badge-success { background: #d1fae5; color: #065f46; }
  .report .badge-warning { background: #fef3c7; color: #92400e; }
  .report .badge-danger { background: #fee2e2; color: #991b1b; }
  .report .badge-info { background: #dbeafe; color: #1e40af; }
  .report .chart-container { background: #fff; border: 1px solid #dbeafe; border-radius: 8px; padding: 16px; margin: 16px 0; }
  .report .note { background: #eff6ff; border-left: 4px solid #2563eb; padding: 10px 14px; border-radius: 4px; margin: 12px 0; font-size: 0.9em; }
  .report .structure-grid { display: grid; grid-template-columns: repeat(auto-fit, minmax(340px, 1fr)); gap: 16px; margin: 12px 0; }
</style>
"""

def _wrap_report(html: str) -> str:
    return f'{REPORT_CSS}<div class="report">{html}</div>'

# ------------------------------------------------------------------
# SVG chart helpers
# ------------------------------------------------------------------

def _make_heatmap(
    matrix: list[list[float]],
    row_labels: list[str],
    col_labels: list[str],
    title: str = "",
    width: int = 600,
    height: int = 500,
    value_format: str = ".1f",
    color_scale: str = "blue",
) -> str:
    """Generate an SVG heatmap."""
    n_rows = len(matrix)
    n_cols = len(matrix[0]) if matrix else 0
    if not n_rows or not n_cols:
        return ""

    show_values = n_rows <= 10 and n_cols <= 10
    flat = [v for row in matrix for v in row]
    v_min = min(flat)
    v_max = max(flat)
    v_range = v_max - v_min or 1

    if color_scale == "blue":
        def get_color(v):
            t = (v - v_min) / v_range
            r = int(255 - t * (255 - 30))
            g = int(255 - t * (255 - 58))
            b = int(255 - t * (255 - 95))
            return f"rgb({r},{g},{b})"
    else:  # green
        def get_color(v):
            t = (v - v_min) / v_range
            r = int(255 - t * (255 - 6))
            g = int(255 - t * (255 - 95))
            b = int(255 - t * (255 - 70))
            return f"rgb({r},{g},{b})"

    ml = max(80, max(len(l) for l in row_labels) * 7 + 10) if row_labels else 80
    mr = 20
    mt = 80
    mb = 20
    cw = width - ml - mr
    ch = height - mt - mb
    cell_w = cw / n_cols
    cell_h = ch / n_rows

    svg = [
        f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {height}" '
        f'style="width:100%;max-width:{width}px;height:auto;">',
        f'<rect width="{width}" height="{height}" fill="#fff" rx="6"/>',
    ]

    if title:
        svg.append(f'<text x="{width / 2}" y="22" text-anchor="middle" font-size="14" font-weight="600" fill="#1a1a2e">{title}</text>')

    for j, label in enumerate(col_labels):
        cx = ml + j * cell_w + cell_w / 2
        svg.append(f'<text x="{cx:.1f}" y="{mt - 8}" text-anchor="start" font-size="10" fill="#374151" transform="rotate(-45, {cx:.1f}, {mt - 8})">{label}</text>')

    for i, row_label in enumerate(row_labels):
        ry = mt + i * cell_h + cell_h / 2
        svg.append(f'<text x="{ml - 8}" y="{ry + 4:.1f}" text-anchor="end" font-size="10" fill="#374151">{row_label}</text>')
        for j in range(n_cols):
            val = matrix[i][j]
            color = get_color(val)
            cx = ml + j * cell_w
            cy = mt + i * cell_h
            svg.append(f'<rect x="{cx:.1f}" y="{cy:.1f}" width="{cell_w:.1f}" height="{cell_h:.1f}" fill="{color}" stroke="#fff" stroke-width="1"/>')
            if show_values:
                t = (val - v_min) / v_range
                text_color = "#fff" if t > 0.55 else "#1a1a2e"
                fs = min(10, int(cell_w / 4), int(cell_h / 2.5))
                fs = max(7, fs)
                svg.append(f'<text x="{cx + cell_w / 2:.1f}" y="{cy + cell_h / 2 + 3:.1f}" text-anchor="middle" font-size="{fs}" fill="{text_color}">{val:{value_format}}</text>')

    svg.append("</svg>")
    return "\n".join(svg)

def _make_dendrogram(
    names: list[str],
    matrix: list[list[float]],
    title: str = "",
    width: int = 700,
    height: int = 350,
    color: str = "#2563eb",
) -> str:
    """Generate an SVG dendrogram from a similarity matrix using UPGMA."""
    n = len(names)
    if n < 2:
        return ""

    dist = [[1.0 - matrix[i][j] for j in range(n)] for i in range(n)]

    clusters = [{"members": [i], "height": 0.0, "left": None, "right": None} for i in range(n)]
    active = list(range(n))

    while len(active) > 1:
        best_d = float("inf")
        bi, bj = 0, 1
        for ii in range(len(active)):
            for jj in range(ii + 1, len(active)):
                ci, cj = active[ii], active[jj]
                d = 0
                count = 0
                for mi in clusters[ci]["members"]:
                    for mj in clusters[cj]["members"]:
                        d += dist[mi][mj]
                        count += 1
                avg_d = d / count if count else 0
                if avg_d < best_d:
                    best_d = avg_d
                    bi, bj = ii, jj

        ci, cj = active[bi], active[bj]
        new_cluster = {
            "members": clusters[ci]["members"] + clusters[cj]["members"],
            "height": best_d,
            "left": clusters[ci],
            "right": clusters[cj],
        }
        clusters.append(new_cluster)
        new_idx = len(clusters) - 1
        active.pop(bj)
        active.pop(bi)
        active.append(new_idx)

    root = clusters[active[0]]

    max_label_len = max((len(name) for name in names), default=0)
    ml, mr, mt, mb = max(50, max_label_len * 5 + 10), 30, 40, 80
    cw = width - ml - mr
    ch = height - mt - mb
    max_h = root["height"] or 1

    leaf_positions = {}
    leaf_counter = [0]

    def assign_leaves(node):
        if node["left"] is None and node["right"] is None:
            leaf_positions[node["members"][0]] = leaf_counter[0]
            leaf_counter[0] += 1
        else:
            if node["left"]:
                assign_leaves(node["left"])
            if node["right"]:
                assign_leaves(node["right"])

    assign_leaves(root)
    n_leaves = len(leaf_positions)
    leaf_spacing = cw / max(n_leaves - 1, 1)

    svg = [
        f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {height}" '
        f'style="width:100%;max-width:{width}px;height:auto;">',
        f'<rect width="{width}" height="{height}" fill="#fff" rx="6"/>',
    ]

    if title:
        svg.append(f'<text x="{width / 2}" y="22" text-anchor="middle" font-size="13" font-weight="600" fill="#1a1a2e">{title}</text>')

    def get_x(node):
        if node["left"] is None and node["right"] is None:
            return ml + leaf_positions[node["members"][0]] * leaf_spacing
        return (get_x(node["left"]) + get_x(node["right"])) / 2

    def get_y(h):
        return mt + ch - (h / max_h) * ch

    def draw_node(node):
        if node["left"] is None and node["right"] is None:
            return
        lx = get_x(node["left"])
        rx = get_x(node["right"])
        ly = get_y(node["left"]["height"])
        ry = get_y(node["right"]["height"])
        my = get_y(node["height"])

        svg.append(f'<line x1="{lx:.1f}" y1="{ly:.1f}" x2="{lx:.1f}" y2="{my:.1f}" stroke="{color}" stroke-width="2"/>')
        svg.append(f'<line x1="{rx:.1f}" y1="{ry:.1f}" x2="{rx:.1f}" y2="{my:.1f}" stroke="{color}" stroke-width="2"/>')
        svg.append(f'<line x1="{lx:.1f}" y1="{my:.1f}" x2="{rx:.1f}" y2="{my:.1f}" stroke="{color}" stroke-width="2"/>')

        if node["left"]:
            draw_node(node["left"])
        if node["right"]:
            draw_node(node["right"])

    draw_node(root)

    for idx, pos in leaf_positions.items():
        x = ml + pos * leaf_spacing
        svg.append(
            f'<text x="{x:.1f}" y="{mt + ch + 14}" text-anchor="start" font-size="10" fill="#374151" '
            f'transform="rotate(40, {x:.1f}, {mt + ch + 14})">{names[idx]}</text>'
        )

    for i in range(5):
        d = max_h * i / 4
        y = get_y(d)
        svg.append(f'<text x="{ml - 4}" y="{y + 3:.1f}" text-anchor="end" font-size="9" fill="#9ca3af">{d:.3f}</text>')

    svg.append("</svg>")
    return "\n".join(svg)

def _make_bar_chart(
    labels: list[str],
    series: dict[str, list[float]],
    title: str = "",
    colors: list[str] | None = None,
    width: int = 700,
    height: int = 300,
    value_format: str = ".1f",
) -> str:
    """Generate an SVG grouped bar chart."""
    if not labels:
        return ""

    default_colors = ["#2563eb", "#059669", "#f59e0b", "#dc2626", "#7c3aed"]
    colors = colors or default_colors

    ml, mr, mt, mb = 60, 20, 40, 80
    cw = width - ml - mr
    ch = height - mt - mb

    all_vals = [v for vals in series.values() for v in vals]
    y_min = min(all_vals) if all_vals else 0
    y_max = max(all_vals) if all_vals else 1
    if y_min >= 0:
        y_min_plot = 0
        y_max_plot = y_max * 1.15 or 1
    else:
        y_range = y_max - y_min or 1
        y_min_plot = y_min - y_range * 0.05
        y_max_plot = y_max + y_range * 0.15

    n_groups = len(labels)
    n_series = len(series)
    group_width = cw / n_groups
    bar_width = group_width * 0.7 / max(n_series, 1)
    gap = group_width * 0.15

    def sy(v):
        return mt + ch - ((v - y_min_plot) / (y_max_plot - y_min_plot)) * ch

    svg = [
        f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {height}" '
        f'style="width:100%;max-width:{width}px;height:auto;">',
        f'<rect width="{width}" height="{height}" fill="#fff" rx="6"/>',
    ]

    for i in range(6):
        y_tick = y_min_plot + (y_max_plot - y_min_plot) * i / 5
        py = sy(y_tick)
        svg.append(f'<line x1="{ml}" y1="{py:.1f}" x2="{ml + cw}" y2="{py:.1f}" stroke="#e9ecef" stroke-width="1"/>')
        svg.append(f'<text x="{ml - 8}" y="{py + 4:.1f}" text-anchor="end" font-size="11" fill="#6c757d">{y_tick:{value_format}}</text>')

    for gi, label in enumerate(labels):
        gx = ml + gi * group_width + gap
        for si, (name, vals) in enumerate(series.items()):
            color = colors[si % len(colors)]
            bx = gx + si * bar_width
            val = vals[gi]
            by = sy(val)
            bh = mt + ch - by
            svg.append(f'<rect x="{bx:.1f}" y="{by:.1f}" width="{bar_width - 1:.1f}" height="{max(0, bh):.1f}" fill="{color}" rx="2"/>')
            svg.append(f'<text x="{bx + bar_width / 2:.1f}" y="{by - 4:.1f}" text-anchor="middle" font-size="9" fill="#1a1a2e">{val:{value_format}}</text>')
        lx = gx + n_series * bar_width / 2
        svg.append(f'<text x="{lx:.1f}" y="{mt + ch + 14}" text-anchor="start" font-size="10" fill="#6c757d" transform="rotate(35, {lx:.1f}, {mt + ch + 14})">{label}</text>')

    if title:
        svg.append(f'<text x="{width / 2}" y="22" text-anchor="middle" font-size="14" font-weight="600" fill="#1a1a2e">{title}</text>')

    if n_series > 1:
        lx = ml + cw - len(series) * 110
        for si, name in enumerate(series):
            color = colors[si % len(colors)]
            svg.append(f'<rect x="{lx + si * 110}" y="{mt + ch + 55}" width="12" height="12" rx="2" fill="{color}"/>')
            svg.append(f'<text x="{lx + si * 110 + 16}" y="{mt + ch + 66}" font-size="11" fill="#1a1a2e">{name}</text>')

    svg.append("</svg>")
    return "\n".join(svg)

def _make_plddt_sparkline(values: list[float], width: int = 400, height: int = 50) -> str:
    """pLDDT sparkline with AlphaFold-style coloring."""
    if not values or len(values) < 2:
        return ""

    pad = 4
    cw = width - 2 * pad
    ch = height - 2 * pad
    seg_w = cw / len(values)

    svg = [
        f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {height}" '
        f'style="width:100%;max-width:{width}px;height:auto;">',
    ]

    for i, v in enumerate(values):
        x = pad + i * seg_w
        bar_h = (v / 100) * ch
        y = pad + ch - bar_h

        if v >= 90:
            color = "#0053d6"
        elif v >= 70:
            color = "#65cbf3"
        elif v >= 50:
            color = "#ffdb13"
        else:
            color = "#ff7d45"

        svg.append(f'<rect x="{x:.1f}" y="{y:.1f}" width="{max(seg_w, 1):.1f}" height="{bar_h:.1f}" fill="{color}"/>')

    ref_y = pad + ch - (70 / 100) * ch
    svg.append(f'<line x1="{pad}" y1="{ref_y:.1f}" x2="{pad + cw}" y2="{ref_y:.1f}" stroke="#adb5bd" stroke-width="0.5" stroke-dasharray="3,2"/>')

    svg.append("</svg>")
    return "\n".join(svg)

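def _outputs_to_pdb(outputs, sequence: str) -> str:
    """Convert ESMFold outputs to a backbone-only (N, CA, C, O) PDB string."""
    # Hugging Face's EsmForProteinFolding returns positions shaped
    # (structure-module blocks, batch, residues, 14, 3); the last block holds
    # the final coordinates, and the batch here contains a single sequence.
    pos = outputs.positions[-1]
    if pos.dim() == 4:
        pos = pos[0]
    positions = pos.cpu().numpy()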
    atom_names = ["N", "CA", "C", "O"]
    aa_3letter = {
        "A": "ALA", "R": "ARG", "N": "ASN", "D": "ASP", "C": "CYS",
        "Q": "GLN", "E": "GLU", "G": "GLY", "H": "HIS", "I": "ILE",
        "L": "LEU", "K": "LYS", "M": "MET", "F": "PHE", "P": "PRO",
        "S": "SER", "T": "THR", "W": "TRP", "Y": "TYR", "V": "VAL",
    }

    pdb_lines = []
    atom_idx = 1
    for res_idx, aa in enumerate(sequence):
        res_name = aa_3letter.get(aa, "UNK")
        for atom_i, atom_name in enumerate(atom_names):
            if atom_i >= positions.shape[1]:
                break
            x, y, z = positions[res_idx, atom_i]
            if any(math.isnan(c) for c in (x, y, z)):
                continue
            pdb_lines.append(
                f"ATOM  {atom_idx:5d}  {atom_name:<3s} {res_name} A{res_idx + 1:4d}    "
                f"{x:8.3f}{y:8.3f}{z:8.3f}  1.00  0.00           {atom_name[0]:>2s}"
            )
            atom_idx += 1
    pdb_lines.append("END")
    return "\n".join(pdb_lines)

# ------------------------------------------------------------------
# Task 1: Load gene set
# ------------------------------------------------------------------

@cpu_env.task()
async def load_genes(
    gene_set: str = "insulin",
    custom_json: str = "",
) -> flyte.io.Dir:
    """Load a set of homologous genes from different species."""
    if custom_json:
        data = json.loads(custom_json)
    elif gene_set in GENE_SETS:
        data = GENE_SETS[gene_set]
    else:
        available = ", ".join(GENE_SETS.keys())
        raise ValueError(f"Unknown gene set '{gene_set}'. Available: {available}")

    log.info(f"Loaded gene set: {data['gene_name']} - {len(data['sequences'])} species")

    out_dir = tempfile.mkdtemp(prefix="gene_compare_")
    with open(os.path.join(out_dir, "genes.json"), "w") as f:
        json.dump(data, f)

    return await flyte.io.Dir.from_local(out_dir)

# ------------------------------------------------------------------
# Task 2: Score sequences with Carbon
# ------------------------------------------------------------------

@gpu_env.task(report=True)
async def score_sequences(
    genes_dir: flyte.io.Dir,
    model_name: str = "HuggingFaceBio/Carbon-3B",
) -> str:
    """Score each species' gene with the Carbon-3B genomic language model.

    Returns per-species log-likelihood scores and sequence metadata.
    """
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    log.info(f"Loading Carbon model: {model_name}")
    device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name, trust_remote_code=True,
        dtype=torch.bfloat16 if device == "cuda" else torch.float32,
    ).to(device)
    model.eval()

    genes_path = await genes_dir.download()
    with open(os.path.join(genes_path, "genes.json")) as f:
        data = json.load(f)

    species_names = list(data["sequences"].keys())
    n = len(species_names)

    scores = {}
    for i, species in enumerate(species_names):
        await flyte.report.replace.aio(_wrap_report(
            f"<h2>Carbon Scoring</h2>"
            f"<p>Scoring {species} ({i + 1}/{n})...</p>"
        ), do_flush=True)

        dna = data["sequences"][species]["dna"]
        prompt = f"<dna>{dna}"
        inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(device)

        with torch.no_grad():
            output = model(**inputs, labels=inputs["input_ids"])
            loss = output.loss.item()
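            # A standard Hugging Face causal-LM loss is the mean over the n - 1
            # shifted next-token predictions, so scale by n - 1 for the total LL.
            n_predicted = inputs["input_ids"].shape[1] - 1
            ll = -loss * n_predicted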

        protein = _translate(dna)
        scores[species] = {
            "log_likelihood": round(ll, 4),
            "loss": round(loss, 4),
            "gc_content": round(_gc_content(dna), 4),
            "length": len(dna),
            "protein": protein,
            "protein_length": len(protein),
            "common_name": data["sequences"][species]["common_name"],
        }
        log.info(f"  {species} ({data['sequences'][species]['common_name']}): LL={ll:.2f}, GC={_gc_content(dna):.1%}")

    # Report
    html_parts = [
        f"<h2>{data['gene_name']} - Carbon Scoring</h2>",
        f'<div class="note">{data["description"]}</div>',
        '<div class="stat-grid">',
        f'<div class="stat"><div class="value">{n}</div><div class="label">Species</div></div>',
        f'<div class="stat"><div class="value">{data["gene_name"]}</div><div class="label">Gene</div></div>',
        f'<div class="stat"><div class="value">{model_name.split("/")[-1]}</div><div class="label">Model</div></div>',
        "</div>",
    ]

    html_parts.append(
        "<table><tr><th>Species</th><th>Scientific Name</th><th>DNA Length</th>"
        "<th>GC%</th><th>Protein Length</th><th>Carbon LL</th></tr>"
    )
    for species in species_names:
        s = scores[species]
        html_parts.append(
            f'<tr><td><b>{species}</b></td><td><i>{s["common_name"]}</i></td>'
            f'<td>{s["length"]}bp</td><td>{s["gc_content"]:.1%}</td>'
            f'<td>{s["protein_length"]}aa</td>'
            f'<td>{s["log_likelihood"]:.2f}</td></tr>'
        )
    html_parts.append("</table>")

    html_parts.append('<div class="chart-container">')
    html_parts.append(_make_bar_chart(
        species_names,
        {"Log-Likelihood": [scores[s]["log_likelihood"] for s in species_names]},
        title="Carbon Log-Likelihood per Species",
        value_format=".1f",
    ))
    html_parts.append("</div>")

    await flyte.report.replace.aio(_wrap_report("\n".join(html_parts)), do_flush=True)

    result = {
        "gene_name": data["gene_name"],
        "description": data["description"],
        "species": species_names,
        "scores": scores,
    }
    return json.dumps(result)

# ------------------------------------------------------------------
# Task 3: Align sequences and compute similarity
# ------------------------------------------------------------------

@cpu_env.task(report=True)
async def align_and_compare(
    scores_json: str,
    genes_dir: flyte.io.Dir,
) -> str:
    """Align sequences with Needleman-Wunsch and compute pairwise identity.

    Translates DNA to protein, builds DNA and protein identity matrices,
    and generates phylogenetic trees from sequence divergence.
    """
    scores_data = json.loads(scores_json)
    species_names = scores_data["species"]
    scores = scores_data["scores"]
    gene_name = scores_data["gene_name"]
    n = len(species_names)

    genes_path = await genes_dir.download()
    with open(os.path.join(genes_path, "genes.json")) as f:
        data = json.load(f)

    await flyte.report.replace.aio(_wrap_report(
        f"<h2>{gene_name} - Sequence Alignment</h2>"
        f"<p>Aligning {n} species with Needleman-Wunsch...</p>"
    ), do_flush=True)

    # Pairwise DNA identity matrix
    identity_matrix = []
    for sp1 in species_names:
        row = []
        for sp2 in species_names:
            dna1 = data["sequences"][sp1]["dna"]
            dna2 = data["sequences"][sp2]["dna"]
            identity = _sequence_identity(dna1, dna2)
            row.append(round(identity, 4))
        identity_matrix.append(row)

    # Pairwise protein identity matrix
    protein_matrix = []
    for sp1 in species_names:
        row = []
        for sp2 in species_names:
            identity = _sequence_identity(scores[sp1]["protein"], scores[sp2]["protein"])
            row.append(round(identity, 4))
        protein_matrix.append(row)

    # Average pairwise identities (exclude diagonal)
    dna_pairs = [identity_matrix[i][j] for i in range(n) for j in range(i + 1, n)]
    prot_pairs = [protein_matrix[i][j] for i in range(n) for j in range(i + 1, n)]
    avg_dna = sum(dna_pairs) / len(dna_pairs) if dna_pairs else 0
    avg_prot = sum(prot_pairs) / len(prot_pairs) if prot_pairs else 0

    # Most/least similar pair
    best_pair = max(range(len(dna_pairs)), key=lambda k: dna_pairs[k])
    worst_pair = min(range(len(dna_pairs)), key=lambda k: dna_pairs[k])
    pair_indices = [(i, j) for i in range(n) for j in range(i + 1, n)]
    best_sp = f"{species_names[pair_indices[best_pair][0]]}-{species_names[pair_indices[best_pair][1]]}"
    worst_sp = f"{species_names[pair_indices[worst_pair][0]]}-{species_names[pair_indices[worst_pair][1]]}"

    # Report
    html_parts = [
        f"<h2>{gene_name} - Sequence Alignment</h2>",
        f'<div class="note">Pairwise alignment using Needleman-Wunsch (match=2, mismatch=-1, gap=-2). '
        f"Identity is computed as matches / aligned length from the optimal global alignment.</div>",
        '<div class="stat-grid">',
        f'<div class="stat"><div class="value">{n}</div><div class="label">Species Aligned</div></div>',
        f'<div class="stat"><div class="value">{n * (n - 1) // 2}</div><div class="label">Pairwise Alignments</div></div>',
        f'<div class="stat"><div class="value">{avg_dna:.0%}</div><div class="label">Avg DNA Identity</div></div>',
        f'<div class="stat"><div class="value">{avg_prot:.0%}</div><div class="label">Avg Protein Identity</div></div>',
        f'<div class="stat"><div class="value">{best_sp}</div><div class="label">Most Similar</div></div>',
        f'<div class="stat"><div class="value">{worst_sp}</div><div class="label">Most Divergent</div></div>',
        "</div>",
    ]

    # DNA identity heatmap
    html_parts.append('<div class="chart-container">')
    html_parts.append(_make_heatmap(
        identity_matrix, species_names, species_names,
        title="Pairwise DNA Sequence Identity (%)",
        value_format=".0%",
    ))
    html_parts.append("</div>")

    # Protein identity heatmap
    html_parts.append('<div class="chart-container">')
    html_parts.append(_make_heatmap(
        protein_matrix, species_names, species_names,
        title="Pairwise Protein Sequence Identity (%)",
        value_format=".0%",
        color_scale="green",
    ))
    html_parts.append("</div>")

    # DNA phylogenetic tree
    html_parts.append('<div class="chart-container">')
    html_parts.append(_make_dendrogram(
        species_names, identity_matrix,
        title=f"{gene_name} - Phylogenetic Tree (DNA Identity)",
    ))
    html_parts.append("</div>")

    # Protein phylogenetic tree
    html_parts.append('<div class="chart-container">')
    html_parts.append(_make_dendrogram(
        species_names, protein_matrix,
        title=f"{gene_name} - Phylogenetic Tree (Protein Identity)",
        color="#059669",
    ))
    html_parts.append("</div>")

    # DNA vs Protein conservation comparison
    html_parts.append('<div class="chart-container">')
    html_parts.append(_make_bar_chart(
        species_names,
        {
            f"DNA vs {species_names[0]}": [identity_matrix[0][j] for j in range(n)],
            f"Protein vs {species_names[0]}": [protein_matrix[0][j] for j in range(n)],
        },
        title=f"Conservation vs {species_names[0]} (DNA and Protein)",
        value_format=".0%",
    ))
    html_parts.append("</div>")

    await flyte.report.replace.aio(_wrap_report("\n".join(html_parts)), do_flush=True)

    result = {
        "gene_name": gene_name,
        "description": scores_data["description"],
        "species": species_names,
        "scores": scores,
        "dna_identity_matrix": identity_matrix,
        "protein_identity_matrix": protein_matrix,
    }
    return json.dumps(result)

# ------------------------------------------------------------------
# Task 4: Fold proteins with ESMFold
# ------------------------------------------------------------------

@gpu_env.task(report=True)
async def fold_proteins(
    comparison_json: str,
    max_length: int = 400,
) -> str:
    """Fold each species' translated protein with ESMFold for 3D comparison.

    Returns PDB strings and pLDDT confidence scores for each species.
    """
    import torch
    import numpy as np
    from transformers import AutoTokenizer, EsmForProteinFolding

    comparison = json.loads(comparison_json)
    species_names = comparison["species"]
    scores = comparison["scores"]

    log.info("Loading ESMFold model...")
    device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
    model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", low_cpu_mem_usage=True)
    model = model.to(device)
    model.eval()

    structure_data = {}
    n = len(species_names)

    for idx, species in enumerate(species_names):
        protein = scores[species]["protein"]

        if len(protein) > max_length:
            log.info(f"Skipping {species} ({len(protein)} aa > {max_length} max)")
            continue

        log.info(f"ESMFold [{idx + 1}/{n}]: {species} ({len(protein)} aa)")
        await flyte.report.replace.aio(_wrap_report(
            f"<h2>ESMFold - 3D Structure Prediction</h2>"
            f"<p>Folding {species} ({idx + 1}/{n}): {len(protein)} residues...</p>"
        ), do_flush=True)

        inputs = tokenizer(protein, return_tensors="pt", add_special_tokens=False).to(device)

        with torch.no_grad():
            outputs = model(**inputs)

        pdb_str = _outputs_to_pdb(outputs, protein)

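        # pLDDT is returned per atom, shaped (residues, 37) in atom37 order;
        # take the CA atom (index 1) as each residue's confidence.
        plddt_raw = outputs.plddt[0].cpu().numpy()
        if plddt_raw.ndim == 2:
            plddt_raw = plddt_raw[:, 1]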
        plddt = plddt_raw.flatten()[:len(protein)]
        if plddt.max() <= 1.0:
            plddt = plddt * 100
        plddt_mean = float(np.mean(plddt))

        structure_data[species] = {
            "pdb_str": pdb_str,
            "plddt_mean": round(plddt_mean, 1),
            "plddt_per_residue": [round(float(v), 1) for v in plddt[:len(protein)]],
            "protein_length": len(protein),
        }
        log.info(f"  → mean pLDDT: {plddt_mean:.1f}")

    # Report with 3D viewers
    n_folded = len(structure_data)
    avg_plddt = sum(d["plddt_mean"] for d in structure_data.values()) / n_folded if n_folded else 0

    threeDmol_script = '<script src="https://3dmol.csb.pitt.edu/build/3Dmol-min.js"></script>'

    stats_html = f"""
    <h2>ESMFold - Cross-Species Structure Comparison</h2>
    <div class="note">
      <b>ESMFold</b> predicts 3D structure directly from amino acid sequence.
      Comparing structures across species reveals which parts of the protein are
      structurally conserved (functional core) vs divergent (surface loops, species-specific adaptations).
    </div>
    <div class="stat-grid">
      <div class="stat"><div class="value">{n_folded}</div><div class="label">Structures</div></div>
      <div class="stat"><div class="value">{avg_plddt:.1f}</div><div class="label">Avg pLDDT</div></div>
      <div class="stat"><div class="value">{comparison['gene_name']}</div><div class="label">Gene</div></div>
    </div>
    """

    viewers_html = '<div class="structure-grid">'
    for species, sdata in structure_data.items():
        plddt_val = sdata["plddt_mean"]
        common = scores[species]["common_name"]

        if plddt_val >= 90:
            badge = '<span class="badge badge-success">Very High</span>'
        elif plddt_val >= 70:
            badge = '<span class="badge badge-info">Confident</span>'
        elif plddt_val >= 50:
            badge = '<span class="badge badge-warning">Low</span>'
        else:
            badge = '<span class="badge badge-danger">Disordered</span>'

        plddt_sparkline = _make_plddt_sparkline(sdata["plddt_per_residue"], width=300)
        pdb_escaped = sdata["pdb_str"].replace("\\", "\\\\").replace("`", "\\`").replace("$", "\\$")
        viewer_id = f"viewer_{hash(species) & 0xFFFFFF:06x}"

        viewers_html += f"""
        <div class="card" style="margin:0;">
          <h3 style="margin-top:0;">{species}
            <span style="font-size:0.7em;color:#6c757d;">({sdata['protein_length']} aa)</span>
            {badge}
          </h3>
          <p style="font-size:0.85em;color:#6c757d;margin:2px 0 8px;"><i>{common}</i></p>
          <div id="{viewer_id}" style="width:100%;max-width:320px;height:280px;border:1px solid #dbeafe;border-radius:8px;position:relative;"></div>
          <div style="margin-top:8px;">
            <b>Mean pLDDT:</b> {plddt_val:.1f} / 100
            <div style="margin-top:4px;">{plddt_sparkline}</div>
            <div style="font-size:0.75em;color:#9ca3af;margin-top:2px;">
              <span style="color:#0053d6;">&block; &gt;90</span>
              <span style="color:#65cbf3;">&block; 70-90</span>
              <span style="color:#ffdb13;">&block; 50-70</span>
              <span style="color:#ff7d45;">&block; &lt;50</span>
            </div>
          </div>
        </div>
        <script>
        (function() {{
          var pdb = `{pdb_escaped}`;
          function initViewer() {{
            if (typeof $3Dmol === 'undefined') {{ setTimeout(initViewer, 200); return; }}
            var viewer = $3Dmol.createViewer(document.getElementById("{viewer_id}"), {{backgroundColor: "white"}});
            viewer.addModel(pdb, "pdb");
            viewer.setStyle({{}}, {{cartoon: {{color: "spectrum"}}}});
            viewer.zoomTo();
            viewer.render();
            viewer.spin("y", 1);
          }}
          initViewer();
        }})();
        </script>
        """

    viewers_html += "</div>"

    # pLDDT comparison bar chart
    plddt_chart = _make_bar_chart(
        list(structure_data.keys()),
        {"Mean pLDDT": [d["plddt_mean"] for d in structure_data.values()]},
        title="Structure Confidence Comparison (pLDDT)",
        value_format=".1f",
        colors=["#0053d6"],
    )

    report_html = f"""
    {threeDmol_script}
    {stats_html}
    {viewers_html}
    <div class="chart-container">{plddt_chart}</div>
    """

    await flyte.report.replace.aio(_wrap_report(report_html), do_flush=True)

    return json.dumps(structure_data)

# ------------------------------------------------------------------
# Task 5: Generate summary
# ------------------------------------------------------------------

@cpu_env.task(report=True)
async def generate_summary(
    comparison_json: str,
    structures_json: str,
) -> str:
    """Generate comprehensive cross-species summary."""
    comparison = json.loads(comparison_json)
    structures = json.loads(structures_json)

    species = comparison["species"]
    scores = comparison["scores"]
    gene_name = comparison["gene_name"]
    dna_matrix = comparison["dna_identity_matrix"]
    protein_matrix = comparison["protein_identity_matrix"]

    html_parts = [
        f"<h2>{gene_name} - Cross-Species Evolution Summary</h2>",
        f'<div class="note">{comparison["description"]}</div>',
    ]

    # Key metrics
    # Average pairwise identity (exclude diagonal)
    n = len(species)
    dna_pairs = [dna_matrix[i][j] for i in range(n) for j in range(i + 1, n)]
    protein_pairs = [protein_matrix[i][j] for i in range(n) for j in range(i + 1, n)]
    avg_dna_id = sum(dna_pairs) / len(dna_pairs) if dna_pairs else 0
    avg_protein_id = sum(protein_pairs) / len(protein_pairs) if protein_pairs else 0
    avg_plddt = sum(d["plddt_mean"] for d in structures.values()) / len(structures) if structures else 0

    html_parts.append('<div class="stat-grid">')
    html_parts.append(f'<div class="stat"><div class="value">{n}</div><div class="label">Species</div></div>')
    html_parts.append(f'<div class="stat"><div class="value">{avg_dna_id:.0%}</div><div class="label">Avg DNA Identity</div></div>')
    html_parts.append(f'<div class="stat"><div class="value">{avg_protein_id:.0%}</div><div class="label">Avg Protein Identity</div></div>')
    html_parts.append(f'<div class="stat"><div class="value">{avg_plddt:.1f}</div><div class="label">Avg pLDDT</div></div>')
    html_parts.append(f'<div class="stat"><div class="value">{len(structures)}</div><div class="label">Structures Folded</div></div>')
    html_parts.append("</div>")

    # Full comparison table
    html_parts.append("<h3>Per-Species Detail</h3>")
    html_parts.append(
        "<table><tr><th>Species</th><th>Scientific Name</th><th>DNA (bp)</th>"
        "<th>Protein (aa)</th><th>GC%</th><th>Carbon LL</th><th>pLDDT</th></tr>"
    )
    for sp in species:
        s = scores[sp]
        plddt = structures.get(sp, {}).get("plddt_mean", "N/A")
        plddt_str = f"{plddt:.1f}" if isinstance(plddt, float) else plddt
        html_parts.append(
            f'<tr><td><b>{sp}</b></td><td><i>{s["common_name"]}</i></td>'
            f'<td>{s["length"]}</td><td>{s["protein_length"]}</td>'
            f'<td>{s["gc_content"]:.1%}</td><td>{s["log_likelihood"]:.2f}</td>'
            f'<td>{plddt_str}</td></tr>'
        )
    html_parts.append("</table>")

    # GC content comparison
    html_parts.append('<div class="chart-container">')
    html_parts.append(_make_bar_chart(
        species,
        {"GC Content": [scores[s]["gc_content"] for s in species]},
        title="GC Content Across Species",
        value_format=".2f",
    ))
    html_parts.append("</div>")

    # DNA phylogenetic tree
    html_parts.append("<h3>Phylogenetic Relationships</h3>")
    html_parts.append(
        '<div class="note">'
        "Trees built from pairwise sequence identity using UPGMA clustering. "
        "Species that diverged more recently cluster together. DNA and protein trees "
        "may differ when synonymous mutations dominate."
        "</div>"
    )

    html_parts.append('<div class="chart-container">')
    html_parts.append(_make_dendrogram(
        species, dna_matrix,
        title=f"{gene_name} - DNA Phylogenetic Tree",
    ))
    html_parts.append("</div>")

    html_parts.append('<div class="chart-container">')
    html_parts.append(_make_dendrogram(
        species, protein_matrix,
        title=f"{gene_name} - Protein Phylogenetic Tree",
        color="#059669",
    ))
    html_parts.append("</div>")

    await flyte.report.replace.aio(_wrap_report("\n".join(html_parts)), do_flush=True)

    summary = {
        "gene_name": gene_name,
        "n_species": n,
        "avg_dna_identity": round(avg_dna_id, 4),
        "avg_protein_identity": round(avg_protein_id, 4),
        "avg_plddt": round(avg_plddt, 1),
        "n_structures": len(structures),
    }
    return json.dumps(summary)

# ------------------------------------------------------------------
# Pipeline orchestrator
# ------------------------------------------------------------------

# {{docs-fragment pipeline}}
@cpu_env.task(report=True)
async def pipeline(
    gene_set: str = "insulin",
    model_name: str = "HuggingFaceBio/Carbon-3B",
    custom_json: str = "",
) -> tuple[str, str]:
    """
    End-to-end cross-species gene comparison pipeline.

    Returns (comparison JSON, structures JSON).

    1. Load homologous gene sequences across species
    2. Score with Carbon genomic language model
    3. Align sequences and compute pairwise similarity
    4. Fold translated proteins with ESMFold
    5. Generate comprehensive summary with phylogenetic trees
    """
    log.info(f"Starting cross-species gene comparison pipeline (gene_set={gene_set})...")

    def _pipeline_progress(step: int, label: str) -> str:
        steps = [
            "Load Genes",
            "Carbon Scoring",
            "Sequence Alignment",
            "ESMFold Structures",
            "Generate Summary",
        ]
        dots = ""
        for i, s in enumerate(steps):
            if i + 1 < step:
                icon = '<span style="color:#2563eb;">&#10003;</span>'
            elif i + 1 == step:
                icon = '<span style="color:#2563eb;">&#9679;</span>'
            else:
                icon = '<span style="color:#adb5bd;">&#9675;</span>'
            dots += f"<span style='margin:0 8px;'>{icon} {s}</span>"
        return f"""
        <h2>Cross-Species Gene Comparison</h2>
        <div class="card" style="text-align:center;">{dots}</div>
        <p>{label}</p>
        """

    # Stage 1
    await flyte.report.replace.aio(
        _wrap_report(_pipeline_progress(1, "Loading homologous gene sequences...")),
        do_flush=True,
    )
    genes_dir = await load_genes(gene_set=gene_set, custom_json=custom_json)

    # Stage 2
    await flyte.report.replace.aio(
        _wrap_report(_pipeline_progress(2, "Scoring sequences with Carbon...")),
        do_flush=True,
    )
    scores_json = await score_sequences(genes_dir=genes_dir, model_name=model_name)

    # Stage 3
    await flyte.report.replace.aio(
        _wrap_report(_pipeline_progress(3, "Aligning sequences with Needleman-Wunsch...")),
        do_flush=True,
    )
    comparison_json = await align_and_compare(scores_json=scores_json, genes_dir=genes_dir)

    # Stage 4
    await flyte.report.replace.aio(
        _wrap_report(_pipeline_progress(4, "Folding proteins with ESMFold...")),
        do_flush=True,
    )
    structures_json = await fold_proteins(comparison_json=comparison_json)

    # Stage 5
    await flyte.report.replace.aio(
        _wrap_report(_pipeline_progress(5, "Generating summary report...")),
        do_flush=True,
    )
    summary_json = await generate_summary(
        comparison_json=comparison_json,
        structures_json=structures_json,
    )

    # Final report
    summary = json.loads(summary_json)
    comparison = json.loads(comparison_json)

    final_html = f"""
    <h2>Pipeline Complete</h2>
    <div class="stat-grid">
      <div class="stat"><div class="value">{summary['gene_name']}</div><div class="label">Gene</div></div>
      <div class="stat"><div class="value">{summary['n_species']}</div><div class="label">Species</div></div>
      <div class="stat"><div class="value">{summary['avg_dna_identity']:.0%}</div><div class="label">Avg DNA Identity</div></div>
      <div class="stat"><div class="value">{summary['avg_protein_identity']:.0%}</div><div class="label">Avg Protein Identity</div></div>
      <div class="stat"><div class="value">{summary['avg_plddt']:.1f}</div><div class="label">Avg pLDDT</div></div>
      <div class="stat"><div class="value">{summary['n_structures']}</div><div class="label">3D Structures</div></div>
    </div>
    <div class="card">
      <b>Gene:</b> {summary['gene_name']} |
      <b>Species:</b> {', '.join(comparison['species'])} |
      <b>Model:</b> {model_name}
    </div>
    <div class="note">
      All 5 pipeline stages completed. View individual task reports for DNA/protein
      identity heatmaps, phylogenetic trees, interactive 3D protein structures with
      pLDDT confidence, Carbon log-likelihood scores, and evolutionary analysis.
    </div>
    """

    await flyte.report.replace.aio(_wrap_report(final_html), do_flush=True)
    log.info("Pipeline complete.")
    return comparison_json, structures_json

# {{/docs-fragment pipeline}}

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(pipeline)
    print(run.url)
    run.wait()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/genomic_gene_comparison/genomic_gene_comparison.py*

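Dependencies are declared at the top of the file as inline script metadata, the PEP 723 format that `uv` reads. `flyte.Image.from_uv_script` reads the same block to build the task image: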

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.4.0",
#    "torch>=2.9.0",
#    "transformers>=4.49.0",
#    "accelerate>=0.34.0",
#    "numpy",
# ]
# ///
```

## Orchestrate the pipeline

The top-level `pipeline` task chains five stages: loading the genes, Carbon scoring, sequence alignment, ESMFold folding, and a cross-species summary report.

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.4.0",
#    "torch>=2.9.0",
#    "transformers>=4.49.0",
#    "accelerate>=0.34.0",
#    "numpy",
# ]
# main = "pipeline"
# params = ""
# ///
import json
import logging
import math
import os
import tempfile

import flyte
import flyte.io
import flyte.report

# {{docs-fragment env}}
main_img = flyte.Image.from_uv_script(__file__, name="genomic-gene-comparison", pre=True)

gpu_env = flyte.TaskEnvironment(
    name="genomic-gene-comparison-gpu",
    image=main_img,
    resources=flyte.Resources(cpu=4, memory="32Gi", gpu=1),
)

cpu_env = flyte.TaskEnvironment(
    name="genomic-gene-comparison-cpu",
    image=main_img,
    resources=flyte.Resources(cpu=2, memory="8Gi"),
    depends_on=[gpu_env],
)
# {{/docs-fragment env}}

logging.basicConfig(level=logging.WARNING, format="%(message)s", force=True)
log = logging.getLogger(__name__)
log.setLevel(logging.INFO)

# ------------------------------------------------------------------
# Homologous gene sets - same gene across species
# ------------------------------------------------------------------
# Full-length coding sequences from NCBI RefSeq (stop codon excluded).

GENE_SETS = {
    "insulin": {
        "gene_name": "Insulin",
        "description": "Insulin regulates blood sugar in all vertebrates. Highly conserved across 500M+ years of evolution - even fish insulin can lower blood sugar in humans. Comparing across species reveals which regions are functionally essential (conserved) vs free to drift.",
        "sequences": {
            "Human": {
                "dna": "ATGGCCCTGTGGATGCGCCTCCTGCCCCTGCTGGCGCTGCTGGCCCTCTGGGGACCTGACCCAGCCGCAGCCTTTGTGAACCAACACCTGTGCGGCTCACACCTGGTGGAAGCTCTCTACCTAGTGTGCGGGGAACGAGGCTTCTTCTACACACCCAAGACCCGCCGGGAGGCAGAGGACCTGCAGGTGGGGCAGGTGGAGCTGGGCGGGGGCCCTGGTGCAGGCAGCCTGCAGCCCTTGGCCCTGGAGGGGTCCCTGCAGAAGCGTGGCATTGTGGAACAATGCTGTACCAGCATCTGCTCCCTCTACCAGCTGGAGAACTACTGCAAC",
                "common_name": "Homo sapiens",
            },
            "Mouse": {
                "dna": "ATGGCCCTGTGGATGCGCTTCCTGCCCCTGCTGGCCCTGCTCTTCCTCTGGGAGTCCCACCCCACCCAGGCTTTTGTCAAGCAGCACCTTTGTGGTTCCCACCTGGTGGAGGCTCTCTACCTGGTGTGTGGGGAGCGTGGCTTCTTCTACACACCCATGTCCCGCCGTGAAGTGGAGGACCCACAAGTGGCACAACTGGAGCTGGGTGGAGGCCCGGGAGCAGGTGACCTTCAGACCTTGGCACTGGAGGTGGCCCAGCAGAAGCGTGGCATTGTAGATCAGTGCTGCACCAGCATCTGCTCCCTCTACCAGCTGGAGAACTACTGCAAC",
                "common_name": "Mus musculus",
            },
            "Chicken": {
                "dna": "ATGGCTCTCTGGATCCGATCACTGCCTCTTCTGGCTCTCCTTGTCTTTTCTGGCCCTGGAACCAGCTATGCAGCTGCCAACCAGCACCTCTGTGGCTCCCACTTGGTGGAGGCTCTCTACCTGGTGTGTGGAGAGCGTGGCTTCTTCTACTCCCCCAAAGCCCGACGGGATGTCGAGCAGCCCCTAGTGAGCAGTCCCTTGCGTGGCGAGGCAGGAGTGCTGCCTTTCCAGCAGGAGGAATACGAGAAAGTCAAGCGAGGGATTGTTGAGCAATGCTGCCATAACACGTGTTCCCTCTACCAACTGGAGAACTACTGCAAC",
                "common_name": "Gallus gallus",
            },
            "Zebrafish": {
                "dna": "ATGGCAGTGTGGCTTCAGGCTGGTGCTCTGTTGGTCCTGTTGGTCGTGTCCAGTGTAAGCACTAACCCAGGCACACCGCAGCACCTGTGTGGATCTCATCTGGTCGATGCCCTTTATCTGGTCTGTGGCCCAACAGGCTTCTTCTACAACCCCAAGAGAGACGTTGAGCCCCTTCTGGGTTTCCTTCCTCCTAAATCTGCCCAGGAAACTGAGGTGGCTGACTTTGCATTTAAAGATCATGCCGAGCTGATAAGGAAGAGAGGCATTGTAGAGCAGTGCTGCCACAAACCCTGCAGCATCTTTGAGCTGCAGAACTACTGTAAC",
                "common_name": "Danio rerio",
            },
            "Frog": {
                "dna": "ATGGCTCTATGGATGCAGTGTCTGCCCCTGGTTCTTGTCCTCTTTTTCTCTACACCCAACACCGAAGCTCTAGTTAACCAGCACTTGTGTGGGTCTCACCTGGTAGAAGCCCTGTACTTAGTATGTGGGGATCGAGGCTTCTTCTACTACCCTAAGGTCAAACGGGACATGGAACAAGCACTTGTCAGTGGACCCCAGGATAATGAGTTGGATGGAATGCAGCTCCAGCCTCAGGAGTATCAGAAAATGAAGAGGGGGATTGTGGAGCAATGTTGCCACAGCACATGTTCTCTCTTCCAGCTGGAGAGTTACTGCAAC",
                "common_name": "Xenopus laevis",
            },
            "Cow": {
                "dna": "ATGGCCCTGTGGACACGCCTGGCGCCCCTGCTGGCCCTGCTGGCGCTCTGGGCCCCCGCCCCGGCCCGCGCCTTCGTCAACCAGCATCTGTGTGGCTCCCACCTGGTGGAGGCGCTGTACCTGGTGTGCGGAGAGCGCGGCTTCTTCTACACGCCCAAGGCCCGCCGGGAGGTGGAGGGCCCCCAGGTGGGGGCGCTGGAGCTGGCCGGAGGCCCGGGCGCGGGCGGCCTGGAGGGGCCCCCGCAGAAGCGTGGCATCGTGGAGCAGTGCTGTGCCAGCGTCTGCTCGCTCTACCAGCTGGAGAACTACTGTAAC",
                "common_name": "Bos taurus",
            },
        },
    },
    "hemoglobin": {
        "gene_name": "Hemoglobin Beta",
        "description": "Beta-globin carries oxygen from lungs to tissues. The most studied gene in molecular evolution - sequence differences power the 'molecular clock' hypothesis. Sickle cell mutation (E6V) in humans shows how a single base change creates devastating disease.",
        "sequences": {
            "Human": {
                "dna": "ATGGTGCATCTGACTCCTGAGGAGAAGTCTGCCGTTACTGCCCTGTGGGGCAAGGTGAACGTGGATGAAGTTGGTGGTGAGGCCCTGGGCAGGCTGCTGGTGGTCTACCCTTGGACCCAGAGGTTCTTTGAGTCCTTTGGGGATCTGTCCACTCCTGATGCTGTTATGGGCAACCCTAAGGTGAAGGCTCATGGCAAGAAAGTGCTCGGTGCCTTTAGTGATGGCCTGGCTCACCTGGACAACCTCAAGGGCACCTTTGCCACACTGAGTGAGCTGCACTGTGACAAGCTGCACGTGGATCCTGAGAACTTCAGGCTCCTGGGCAACGTGCTGGTCTGTGTGCTGGCCCATCACTTTGGCAAAGAATTCACCCCACCAGTGCAGGCTGCCTATCAGAAAGTGGTGGCTGGTGTGGCTAATGCCCTGGCCCACAAGTATCAC",
                "common_name": "Homo sapiens",
            },
            "Mouse": {
                "dna": "ATGGTGCACCTGACTGATGCTGAGAAGGCTGCTGTCTCTGGCCTGTGGGGAAAGGTGAACGCCGATGAAGTTGGTGGTGAGGCCCTGGGCAGGCTGCTGGTTGTCTACCCTTGGACCCAGCGGTACTTTGATAGCTTTGGAGACCTATCCTCTGCCTCTGCTATCATGGGTAATGCCAAAGTGAAGGCCCATGGCAAGAAAGTGATAACTGCCTTTAACGATGGCCTGAATCACTTGGACAGCCTCAAGGGCACCTTTGCCAGCCTCAGTGAGCTCCACTGTGACAAGCTGCATGTGGATCCTGAGAACTTCAGGCTCCTGGGCAATATGATCGTGATTGTGCTGGGCCACCACCTGGGCAAGGATTTCACCCCCGCTGCACAGGCTGCCTTCCAGAAGGTGGTGGCTGGAGTGGCTGCTGCCCTGGCTCACAAGTACCAC",
                "common_name": "Mus musculus",
            },
            "Chicken": {
                "dna": "ATGGTGCACTGGACTGCTGAGGAGAAGCAGCTCATCACCGGCCTCTGGGGCAAGGTCAATGTGGCCGAATGTGGGGCTGAAGCCCTGGCCAGGCTGCTGATCGTCTACCCCTGGACCCAGAGGTTCTTTGCGTCCTTTGGGAACCTCTCCAGCCCCACTGCCATCCTTGGCAACCCCATGGTCCGCGCCCATGGCAAGAAAGTGCTCACCTCCTTTGGGGATGCTGTGAAGAACCTGGACAACATCAAGAACACCTTCTCCCAACTGTCCGAACTGCATTGTGACAAGCTGCATGTGGACCCCGAGAACTTCAGGCTCCTGGGTGACATCCTCATCATTGTCCTGGCCGCCCACTTCAGCAAGGACTTCACTCCTGAATGCCAGGCTGCCTGGCAGAAGCTGGTCCGCGTGGTGGCCCATGCCCTGGCTCGCAAGTACCAC",
                "common_name": "Gallus gallus",
            },
            "Zebrafish": {
                "dna": "ATGGTTGAGTGGACAGATGCCGAGCGCACAGCCATCCTTGGCCTGTGGGGAAAGCTCAATATCGATGAAATCGGACCTCAGGCCCTATCCAGATGTCTGATCGTGTATCCCTGGACTCAGAGATATTTCGCCACATTCGGCAACCTGTCAAGCCCCGCTGCGATCATGGGTAACCCCAAAGTGGCAGCTCATGGGAGGACTGTGATGGGAGGTCTTGAGAGAGCCATCAAGAACATGGACAACGTCAAGAACACCTATGCCGCCCTCAGTGTGATGCACTCTGAGAAACTGCATGTGGATCCCGACAACTTCAGGCTTCTCGCTGATTGCATCACCGTTTGCGCTGCCATGAAGTTCGGCCAAGCTGGTTTCAATGCTGATGTCCAGGAGGCCTGGCAGAAGTTTCTGGCTGTGGTCGTTTCTGCTCTGTGCAGACAGTACCAC",
                "common_name": "Danio rerio",
            },
            "Frog": {
                "dna": "ATGGTTCATTGGACAGCTGAAGAGAAGGCCGCCATCACCTCTGTGTGGCAGGAGGTCAACCAGGAGCAAGATGGCCATGATGCACTCACAAGGCTGCTGGTTGTGTACCCCTGGACCCAGAGATACTTCAGCAGTTTTGGAAATCTCGGTAATGCCACAGCTATTGCTGGAAATGTCAAGGTGCGTGCCCATGGCAAGAAGGTTCTTTCAGCTGTTGGTGATGCCATCGCCCATCTTGACAACGTGAAGGGAACTCTCCATGACCTCAGTGTGGTCCACGCCTTCAAGCTCTATGTGGATCCTGAGAACTTCAAGCGTCTTGGTGAAGTTCTGGTCATTGTCTTGGCTTCCAAACTGGGATCAGCCTTTACTCCTCAAGTCCAGGGAGCCTGGGAGAAATTTGTTGCTGTTCTGGTTGATGCCCTCAGCCAAGGATACAAC",
                "common_name": "Xenopus laevis",
            },
            "Cow": {
                "dna": "ATGCTGACTGCTGAGGAGAAGGCTGCCGTCACCGCCTTTTGGGGCAAGGTGAAAGTGGATGAAGTTGGTGGTGAGGCCCTGGGCAGGCTGCTGGTTGTCTACCCCTGGACTCAGAGGTTCTTTGAGTCCTTTGGGGACTTGTCCACTGCTGATGCTGTTATGAACAACCCTAAGGTGAAGGCCCATGGCAAGAAGGTGCTAGATTCCTTTAGTAATGGCATGAAGCATCTCGATGACCTCAAGGGCACCTTTGCTGCGCTGAGTGAGCTGCACTGTGATAAGCTGCATGTGGATCCTGAGAACTTCAAGCTCCTGGGCAACGTGCTAGTGGTTGTGCTGGCTCGCAATTTTGGCAAGGAATTCACCCCGGTGCTGCAGGCTGACTTTCAGAAGGTGGTGGCTGGTGTGGCCAATGCCCTGGCCCACAGATATCAT",
                "common_name": "Bos taurus",
            },
        },
    },
    "p53": {
        "gene_name": "p53 (TP53)",
        "description": "The 'guardian of the genome' - p53 detects DNA damage and triggers repair or cell death. Mutated in >50% of human cancers. Elephants have 20 copies of p53 (humans have 1), which may explain their extremely low cancer rates despite their size (Peto's paradox).",
        "sequences": {
            "Human": {
                "dna": "ATGGAGGAGCCGCAGTCAGATCCTAGCGTCGAGCCCCCTCTGAGTCAGGAAACATTTTCAGACCTATGGAAACTACTTCCTGAAAACAACGTTCTGTCCCCCTTGCCGTCCCAAGCAATGGATGATTTGATGCTGTCCCCGGACGATATTGAACAATGGTTCACTGAAGACCCAGGTCCAGATGAAGCTCCCAGAATGCCAGAGGCTGCTCCCCCCGTGGCCCCTGCACCAGCAGCTCCTACACCGGCGGCCCCTGCACCAGCCCCCTCCTGGCCCCTGTCATCTTCTGTCCCTTCCCAGAAAACCTACCAGGGCAGCTACGGTTTCCGTCTGGGCTTCTTGCATTCTGGGACAGCCAAGTCTGTGACTTGCACGTACTCCCCTGCCCTCAACAAGATGTTTTGCCAACTGGCCAAGACCTGCCCTGTGCAGCTGTGGGTTGATTCCACACCCCCGCCCGGCACCCGCGTCCGCGCCATGGCCATCTACAAGCAGTCACAGCACATGACGGAGGTTGTGAGGCGCTGCCCCCACCATGAGCGCTGCTCAGATAGCGATGGTCTGGCCCCTCCTCAGCATCTTATCCGAGTGGAAGGAAATTTGCGTGTGGAGTATTTGGATGACAGAAACACTTTTCGACATAGTGTGGTGGTGCCCTATGAGCCGCCTGAGGTTGGCTCTGACTGTACCACCATCCACTACAACTACATGTGTAACAGTTCCTGCATGGGCGGCATGAACCGGAGGCCCATCCTCACCATCATCACACTGGAAGACTCCAGTGGTAATCTACTGGGACGGAACAGCTTTGAGGTGCGTGTTTGTGCCTGTCCTGGGAGAGACCGGCGCACAGAGGAAGAGAATCTCCGCAAGAAAGGGGAGCCTCACCACGAGCTGCCCCCAGGGAGCACTAAGCGAGCACTGCCCAACAACACCAGCTCCTCTCCCCAGCCAAAGAAGAAACCACTGGATGGAGAATATTTCACCCTTCAGATCCGTGGGCGTGAGCGCTTCGAGATGTTCCGAGAGCTGAATGAGGCCTTGGAACTCAAGGATGCCCAGGCTGGGAAGGAGCCAGGGGGGAGCAGGGCTCACTCCAGCCACCTGAAGTCCAAAAAGGGTCAGTCTACCTCCCGCCATAAAAAACTCATGTTCAAGACAGAAGGGCCTGACTCAGAC",
                "common_name": "Homo sapiens",
            },
            "Mouse": {
                "dna": "ATGACTGCCATGGAGGAGTCACAGTCGGATATCAGCCTCGAGCTCCCTCTGAGCCAGGAGACATTTTCAGGCTTATGGAAACTACTTCCTCCAGAAGATATCCTGCCATCACCTCACTGCATGGACGATCTGTTGCTGCCCCAGGATGTTGAGGAGTTTTTTGAAGGCCCAAGTGAAGCCCTCCGAGTGTCAGGAGCTCCTGCAGCACAGGACCCTGTCACCGAGACCCCTGGGCCAGTGGCCCCTGCCCCAGCCACTCCATGGCCCCTGTCATCTTTTGTCCCTTCTCAAAAAACTTACCAGGGCAACTATGGCTTCCACCTGGGCTTCCTGCAGTCTGGGACAGCCAAGTCTGTTATGTGCACGTACTCTCCTCCCCTCAATAAGCTATTCTGCCAGCTGGCGAAGACGTGCCCTGTGCAGTTGTGGGTCAGCGCCACACCTCCAGCTGGGAGCCGTGTCCGCGCCATGGCCATCTACAAGAAGTCACAGCACATGACGGAGGTCGTGAGACGCTGCCCCCACCATGAGCGCTGCTCCGATGGTGATGGCCTGGCTCCTCCCCAGCATCTTATCCGGGTGGAAGGAAATTTGTATCCCGAGTATCTGGAAGACAGGCAGACTTTTCGCCACAGCGTGGTGGTACCTTATGAGCCACCCGAGGCCGGCTCTGAGTATACCACCATCCACTACAAGTACATGTGTAATAGCTCCTGCATGGGGGGCATGAACCGCCGACCTATCCTTACCATCATCACACTGGAAGACTCCAGTGGGAACCTTCTGGGACGGGACAGCTTTGAGGTTCGTGTTTGTGCCTGCCCTGGGAGAGACCGCCGTACAGAAGAAGAAAATTTCCGCAAAAAGGAAGTCCTTTGCCCTGAACTGCCCCCAGGGAGCGCAAAGAGAGCGCTGCCCACCTGCACAAGCGCCTCTCCCCCGCAAAAGAAAAAACCACTTGATGGAGAGTATTTCACCCTCAAGATCCGCGGGCGTAAACGCTTCGAGATGTTCCGGGAGCTGAATGAGGCCTTAGAGTTAAAGGATGCCCATGCTACAGAGGAGTCTGGAGACAGCAGGGCTCACTCCAGCTACCTGAAGACCAAGAAGGGCCAGTCTACTTCCCGCCATAAAAAAACAATGGTCAAGAAAGTGGGGCCTGACTCAGAC",
                "common_name": "Mus musculus",
            },
            "Chicken": {
                "dna": "ATGGCGGAGGAGATGGAACCATTGCTGGAACCCACTGAGGTCTTCATGGACCTCTGGAGCATGCTCCCCTATAGCATGCAACAGCTGCCCCTCCCTGAGGATCACAGCAACTGGCAGGAGCTGAGCCCCCTGGAACCCAGCGACCCCCCCCCACCACCGCCACCACCACCTCTGCCATTGGCCGCCGCCGCCCCCCCCCCATTAAACCCCCCCACCCCCCCCCGCGCTGCCCCCTCCCCGGTGGTCCCATCCACGGAGGATTATGGGGGGGACTTCGACTTCCGGGTGGGGTTCGTGGAGGCGGGCACAGCCAAATCGGTCACCTGCACTTACTCCCCGGTGCTGAATAAGGTCTATTGCCGCCTGGCCAAGCCGTGCCCGGTGCAGGTGAGGGTGGGGGTGGCGCCCCCCCCCGGTTCCTCCCTCCGCGCCGTGGCCGTCTATAAGAAATCAGAGCACGTGGCCGAAGTGGTGCGGCGCTGCCCCCACCACGAGCGCTGCGGGGGGGGCACCGACGGCCTGGCCCCCGCACAGCACCTCATCCGGGTGGAGGGGAACCCCCAGGCGCGTTACCACGACGACGAGACCACCAAACGGCACAGCGTCGTCGTCCCCTATGAGCCCCCCGAGGTGGGCTCTGACTGTACCACGGTGCTGTACAACTTCATGTGCAACAGTTCCTGCATGGGGGGGATGAACCGCCGCCCCATCCTCACCATCCTTACACTGGAGGGGCCGGGGGGGCAGCTGTTGGGGCGGCGCTGCTTCGAGGTGCGCGTGTGCGCATGTCCGGGGAGGGACCGCAAGATCGAGGAGGAGAACTTCCGCAAGAGGGGCGGGGCCGGGGGCGTGGCTAAGCGAGCCATGTCGCCCCCAACCGAAGCCCCCGAGCCCCCCAAGAAGCGCGTGCTGAACCCCGACAATGAGATATTCTACCTGCAGGTGCGCGGGCGCCGCCGCTATGAGATGCTGAAGGAGATCAATGAGGCGCTGCAGCTCGCCGAGGGGGGGTCCGCACCGCGGCCTTCCAAAGGCCGCCGTGTGAAGGTGGAGGGACCCCAACCCAGCTGCGGGAAGAAACTGCTGCAAAAAGGCTCGGAC",
                "common_name": "Gallus gallus",
            },
            "Zebrafish": {
                "dna": "ATGGCGCAAAACGACAGCCAAGAGTTCGCGGAGCTCTGGGAGAAGAATTTGATTATTCAGCCCCCAGGTGGTGGCTCTTGCTGGGACATCATTAATGATGAGGAGTACTTGCCGGGATCGTTTGACCCCAATTTTTTTGAAAATGTGCTTGAAGAACAGCCTCAGCCATCCACTCTCCCACCAACATCCACTGTTCCGGAGACAAGCGACTATCCCGGCGATCATGGATTTAGGCTCAGGTTCCCGCAGTCTGGCACAGCAAAATCTGTAACTTGCACTTATTCACCGGACCTGAATAAACTCTTCTGTCAGCTGGCAAAAACTTGCCCCGTTCAAATGGTGGTGGACGTTGCCCCTCCACAGGGCTCCGTGGTTCGAGCCACTGCCATCTATAAGAAGTCCGAGCATGTGGCTGAAGTGGTCCGCAGATGCCCCCATCATGAGCGAACCCCGGATGGAGATAACTTGGCGCCTGCTGGTCATTTGATAAGAGTGGAGGGCAATCAGCGAGCAAATTACAGGGAAGATAACATCACTTTAAGGCATAGTGTTTTTGTCCCATATGAAGCACCACAGCTTGGTGCTGAATGGACAACTGTGCTACTAAACTACATGTGCAATAGCAGCTGCATGGGGGGGATGAACCGCAGGCCCATCCTCACAATCATCACTCTGGAGACTCAGGAAGGTCAGTTGCTGGGCCGGAGGTCTTTTGAGGTGCGTGTGTGTGCATGTCCAGGCAGAGACAGGAAAACTGAGGAGAGCAACTTCAAGAAAGACCAAGAGACCAAAACCATGGCCAAAACCACCACTGGGACCAAACGTAGTTTGGTGAAAGAATCTTCTTCAGCTACATTACGACCTGAGGGGAGCAAAAAGGCCAAGGGCTCCAGCAGCGATGAGGAGATCTTTACCCTGCAGGTGAGGGGCAGGGAGCGTTATGAAATTTTAAAGAAATTGAACGACAGTCTGGAGTTAAGTGATGTGGTGCCTGCCTCAGATGCTGAAAAGTATCGTCAGAAATTCATGACAAAAAACAAAAAAGAGAATCGTGAATCATCTGAGCCCAAACAGGGAAAGAAGCTGATGGTGAAGGACGAAGGAAGAAGCGACTCTGAT",
                "common_name": "Danio rerio",
            },
            "Elephant": {
                "dna": "ATGGAGGAGCCCCAGTCAGATCTCAGCACCGAGCTCCCTCTGAGTCAAGAGACGTTTTCATACTTATGGGAACTCCTTCCTGAGAATCCGGTTCTGTCCCCCACACTACCCCCGGCAGTGGAGGTCATGGACGATCTGCTACTCTCAGAAGACACTGCAAACTGGCTAGAAAGCCAAGTTGAGGCTCAGGGAATGTCCACAACCCCTGCACCAGCCACCCCTACACCGGTGGCCCCCGCACCAGCCACCTCCTGGACCCTGTCATCTTCCGTCCCTTCCCAAAAGACCTACCCTGGCACCTATGGTTTCCGTCTGGGCTTCCTACATTCTGGGACAGCCAAGTCCGTCACCTGCACGTACTCCCCTGACCTTAACAAGCTGTTTTGCCAGCTGGCAAAAACCTGCCCAGTGCAGCTGTGGGTCGCCTCACCACCCCCGCCCGGCACCCGTGTTCGCACCATGGCCATCTACAAGAAGTCAGAGCATATGACGGAGGTCGTCAAGCGCTGCCCCCACCATGAGCGCTGCTCTGACTCTAGCGATGGCCTGGCCCCTCCTCAGCACCTCATCCGGGTGGAAGGAAACCTGCGTGCTGAGTATCTGGAGGACAGCATCACTCTCCGACACAGTGTGGTGGTGCCCTACGAGCCGCCCGAGGTTGGGTCTGACTGTACCACCATCCACTTCAACTTCATGTGTAACAGCTCCTGCATGGGGGGCATGAACCGGCGGCCCATCCTCACCATCATCACACTGGAAGACTCCAGTGGTAATCTGCTGGGACGTAACAGCTTTGAGGTGCGCATTTGTGCCTGTCCTGGAAGAGACAGACGTACAGAAGAAGAAAATTTCCACAAGAAGGGAGAGCCTTGCCCAGAGCCGCCACCCCCTGGGAGGAGCACTAAGCGAGCACTGCCCACCAACACCAGCTCCTCTACCCAGCCAAAGAAGAAGCCACTGGATGAAGAATATTTCACCCTTCAGATCCGTGGGCGTGAACGCTTCAAGATGTTCCTAGAGCTAAATGAGGCCTTGGAGCTGAAGGATGCCCAGGCTGGGAAGGAGCCAGAGGGGAGCCGGGCTCACTCCAGCCCTTCGAAGTCTAAGAAGGGACAGTCTACCTCCCGCCATAAAAAACCAATGTTCAAGAGAGAGGGACCTGACTCAGAC",
                "common_name": "Loxodonta africana",
            },
            "Dog": {
                "dna": "ATGGAGGAGTCGCAGTCAGAGCTCAATATCGACCCCCCTCTGAGCCAGGAGACATTTTCAGAATTGTGGAACCTGCTTCCTGAAAACAATGTTCTGTCTTCGGAGCTGTGCCCAGCAGTGGATGAGCTGCTGCTCCCAGAGAGCGTCGTGAACTGGCTAGACGAAGACTCAGATGATGCTCCCAGGATGCCAGCCACTTCTGCCCCCACAGCCCCTGGACCGGCCCCCTCGTGGCCCCTATCATCCTCTGTCCCTTCCCCGAAGACCTACCCTGGCACCTATGGGTTCCGTTTGGGGTTCCTGCATTCCGGGACAGCCAAGTCTGTTACTTGGACGTACTCCCCTCTCCTCAACAAGTTGTTTTGCCAGCTGGCGAAGACCTGCCCCGTGCAGCTGTGGGTCAGCTCCCCACCCCCACCCAATACCTGCGTCCGCGCTATGGCCATCTATAAGAAGTCGGAGTTCGTGACCGAGGTTGTGCGGCGCTGCCCCCACCATGAACGCTGCTCTGACAGTAGTGACGGTCTTGCCCCTCCTCAGCATCTCATCCGAGTGGAAGGAAATTTGCGGGCCAAGTACCTGGACGACAGAAACACTTTTCGACACAGTGTGGTGGTGCCTTATGAGCCACCCGAGGTTGGCTCTGACTATACCACCATCCACTACAACTACATGTGTAACAGTTCCTGCATGGGAGGCATGAACCGGCGGCCCATCCTCACTATCATCACCCTGGAAGACTCCAGTGGAAACGTGCTGGGACGCAACAGCTTTGAGGTACGCGTTTGTGCCTGTCCCGGGAGAGACCGCCGGACTGAGGAGGAGAATTTCCACAAGAAGGGGGAGCCTTGTCCTGAGCCACCCCCCGGGAGTACCAAGCGAGCACTGCCTCCCAGCACCAGCTCCTCTCCCCCGCAAAAGAAGAAGCCACTAGATGGAGAATATTTCACCCTTCAGATCCGTGGGCGTGAACGCTATGAGATGTTCAGGAATCTGAATGAAGCCTTGGAGCTGAAGGATGCCCAGAGTGGAAAGGAGCCAGGGGGAAGCAGGGCTCACTCCAGCCACCTGAAGGCAAAGAAGGGGCAATCTACCTCTCGCCATAAAAAACTGATGTTCAAGAGAGAAGGGCTTGACTCAGAC",
                "common_name": "Canis lupus familiaris",
            },
        },
    },
}

# Standard genetic code
CODON_TABLE = {
    "TTT": "F", "TTC": "F", "TTA": "L", "TTG": "L", "CTT": "L", "CTC": "L",
    "CTA": "L", "CTG": "L", "ATT": "I", "ATC": "I", "ATA": "I", "ATG": "M",
    "GTT": "V", "GTC": "V", "GTA": "V", "GTG": "V", "TCT": "S", "TCC": "S",
    "TCA": "S", "TCG": "S", "CCT": "P", "CCC": "P", "CCA": "P", "CCG": "P",
    "ACT": "T", "ACC": "T", "ACA": "T", "ACG": "T", "GCT": "A", "GCC": "A",
    "GCA": "A", "GCG": "A", "TAT": "Y", "TAC": "Y", "TAA": "*", "TAG": "*",
    "CAT": "H", "CAC": "H", "CAA": "Q", "CAG": "Q", "AAT": "N", "AAC": "N",
    "AAA": "K", "AAG": "K", "GAT": "D", "GAC": "D", "GAA": "E", "GAG": "E",
    "TGT": "C", "TGC": "C", "TGA": "*", "TGG": "W", "CGT": "R", "CGC": "R",
    "CGA": "R", "CGG": "R", "AGT": "S", "AGC": "S", "AGA": "R", "AGG": "R",
    "GGT": "G", "GGC": "G", "GGA": "G", "GGG": "G",
}

BASE_COLORS = {"A": "#2ecc71", "T": "#e74c3c", "G": "#f39c12", "C": "#3498db"}

def _translate(dna: str) -> str:
    """Translate DNA to protein in reading frame 0."""
    dna = dna.upper()
    protein = []
    for i in range(0, len(dna) - 2, 3):
        codon = dna[i:i + 3]
        aa = CODON_TABLE.get(codon, "X")
        if aa == "*":
            break
        protein.append(aa)
    return "".join(protein)
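
# Illustrative examples for _translate (worked by hand from CODON_TABLE;
# not part of the original tutorial source):
#   _translate("ATGGCCTAA") -> "MA"  (ATG = M, GCC = A, TAA is a stop codon and ends translation)
#   _translate("ATGGC")     -> "M"   (a trailing partial codon is ignored)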

def _gc_content(seq: str) -> float:
    if not seq:
        return 0.0
    return sum(1 for b in seq.upper() if b in "GC") / len(seq)

def _sequence_identity(seq1: str, seq2: str, match: int = 2, mismatch: int = -1, gap: int = -2) -> float:
    """Fraction of identical columns (gap columns included) in a Needleman-Wunsch global alignment."""
    if not seq1 or not seq2:
        return 0.0
    n, m = len(seq1), len(seq2)

    # Build score matrix
    dp = [[0] * (m + 1) for _ in range(n + 1)]
    for i in range(1, n + 1):
        dp[i][0] = dp[i - 1][0] + gap
    for j in range(1, m + 1):
        dp[0][j] = dp[0][j - 1] + gap
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            s = match if seq1[i - 1] == seq2[j - 1] else mismatch
            dp[i][j] = max(dp[i - 1][j - 1] + s, dp[i - 1][j] + gap, dp[i][j - 1] + gap)

    # Traceback to count matches and alignment length
    i, j = n, m
    matches = 0
    aligned = 0
    while i > 0 or j > 0:
        if i > 0 and j > 0:
            s = match if seq1[i - 1] == seq2[j - 1] else mismatch
            if dp[i][j] == dp[i - 1][j - 1] + s:
                if seq1[i - 1] == seq2[j - 1]:
                    matches += 1
                aligned += 1
                i -= 1
                j -= 1
                continue
        if i > 0 and dp[i][j] == dp[i - 1][j] + gap:
            aligned += 1
            i -= 1
        else:
            aligned += 1
            j -= 1

    return matches / aligned if aligned else 0.0
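
# Illustrative worked example (computed by hand with the default scores
# match=2, mismatch=-1, gap=-2; not part of the original tutorial source):
#   _sequence_identity("ACGT", "AGT") recovers the alignment
#       ACGT
#       A-GT
#   with 3 identical columns out of 4 aligned columns. The gap column counts
#   toward the length, so the function returns 3 / 4 = 0.75.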

# ------------------------------------------------------------------
# Report styling
# ------------------------------------------------------------------

REPORT_CSS = """
<style>
  .report { font-family: system-ui, -apple-system, sans-serif; max-width: 960px; margin: 0 auto; color: #1a1a2e; }
  .report h2 { color: #1e3a5f; border-bottom: 2px solid #2563eb; padding-bottom: 8px; margin-top: 24px; }
  .report h3 { color: #1e40af; margin-top: 20px; }
  .report .card { background: #eff6ff; border: 1px solid #bfdbfe; border-radius: 8px; padding: 16px; margin: 12px 0; }
  .report .stat-grid { display: grid; grid-template-columns: repeat(auto-fit, minmax(140px, 1fr)); gap: 12px; margin: 12px 0; }
  .report .stat { background: #fff; border: 1px solid #dbeafe; border-radius: 6px; padding: 12px; text-align: center; }
  .report .stat .value { font-size: 1.5em; font-weight: 700; color: #1e3a5f; }
  .report .stat .label { font-size: 0.85em; color: #6c757d; margin-top: 4px; }
  .report table { border-collapse: collapse; width: 100%; margin: 12px 0; }
  .report th { background: #1e3a5f; color: #fff; padding: 10px 14px; text-align: left; font-weight: 600; }
  .report td { padding: 8px 14px; border-bottom: 1px solid #dbeafe; }
  .report tr:nth-child(even) { background: #eff6ff; }
  .report .badge { display: inline-block; padding: 2px 8px; border-radius: 12px; font-size: 0.8em; font-weight: 600; }
  .report .badge-success { background: #d1fae5; color: #065f46; }
  .report .badge-warning { background: #fef3c7; color: #92400e; }
  .report .badge-danger { background: #fee2e2; color: #991b1b; }
  .report .badge-info { background: #dbeafe; color: #1e40af; }
  .report .chart-container { background: #fff; border: 1px solid #dbeafe; border-radius: 8px; padding: 16px; margin: 16px 0; }
  .report .note { background: #eff6ff; border-left: 4px solid #2563eb; padding: 10px 14px; border-radius: 4px; margin: 12px 0; font-size: 0.9em; }
  .report .structure-grid { display: grid; grid-template-columns: repeat(auto-fit, minmax(340px, 1fr)); gap: 16px; margin: 12px 0; }
</style>
"""

def _wrap_report(html: str) -> str:
    return f'{REPORT_CSS}<div class="report">{html}</div>'

# ------------------------------------------------------------------
# SVG chart helpers
# ------------------------------------------------------------------

def _make_heatmap(
    matrix: list[list[float]],
    row_labels: list[str],
    col_labels: list[str],
    title: str = "",
    width: int = 600,
    height: int = 500,
    value_format: str = ".1f",
    color_scale: str = "blue",
) -> str:
    """Generate an SVG heatmap."""
    n_rows = len(matrix)
    n_cols = len(matrix[0]) if matrix else 0
    if not n_rows or not n_cols:
        return ""

    show_values = n_rows <= 10 and n_cols <= 10
    flat = [v for row in matrix for v in row]
    v_min = min(flat)
    v_max = max(flat)
    v_range = v_max - v_min or 1

    if color_scale == "blue":
        def get_color(v):
            t = (v - v_min) / v_range
            r = int(255 - t * (255 - 30))
            g = int(255 - t * (255 - 58))
            b = int(255 - t * (255 - 95))
            return f"rgb({r},{g},{b})"
    else:  # green
        def get_color(v):
            t = (v - v_min) / v_range
            r = int(255 - t * (255 - 6))
            g = int(255 - t * (255 - 95))
            b = int(255 - t * (255 - 70))
            return f"rgb({r},{g},{b})"
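
    # Both scales interpolate linearly from white (t = 0) to a dark accent color
    # (t = 1). Blue ends at rgb(30,58,95) = #1e3a5f and green at
    # rgb(6,95,70) = #065f46, matching colors used in REPORT_CSS. As a hand-worked
    # illustration, a value halfway through the range (t = 0.5) on the blue scale
    # maps to rgb(142,156,175) after int() truncation.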

    ml = max(80, max(len(l) for l in row_labels) * 7 + 10) if row_labels else 80
    mr = 20
    mt = 80
    mb = 20
    cw = width - ml - mr
    ch = height - mt - mb
    cell_w = cw / n_cols
    cell_h = ch / n_rows

    svg = [
        f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {height}" '
        f'style="width:100%;max-width:{width}px;height:auto;">',
        f'<rect width="{width}" height="{height}" fill="#fff" rx="6"/>',
    ]

    if title:
        svg.append(f'<text x="{width / 2}" y="22" text-anchor="middle" font-size="14" font-weight="600" fill="#1a1a2e">{title}</text>')

    for j, label in enumerate(col_labels):
        cx = ml + j * cell_w + cell_w / 2
        svg.append(f'<text x="{cx:.1f}" y="{mt - 8}" text-anchor="start" font-size="10" fill="#374151" transform="rotate(-45, {cx:.1f}, {mt - 8})">{label}</text>')

    for i, row_label in enumerate(row_labels):
        ry = mt + i * cell_h + cell_h / 2
        svg.append(f'<text x="{ml - 8}" y="{ry + 4:.1f}" text-anchor="end" font-size="10" fill="#374151">{row_label}</text>')
        for j in range(n_cols):
            val = matrix[i][j]
            color = get_color(val)
            cx = ml + j * cell_w
            cy = mt + i * cell_h
            svg.append(f'<rect x="{cx:.1f}" y="{cy:.1f}" width="{cell_w:.1f}" height="{cell_h:.1f}" fill="{color}" stroke="#fff" stroke-width="1"/>')
            if show_values:
                t = (val - v_min) / v_range
                text_color = "#fff" if t > 0.55 else "#1a1a2e"
                fs = min(10, int(cell_w / 4), int(cell_h / 2.5))
                fs = max(7, fs)
                svg.append(f'<text x="{cx + cell_w / 2:.1f}" y="{cy + cell_h / 2 + 3:.1f}" text-anchor="middle" font-size="{fs}" fill="{text_color}">{val:{value_format}}</text>')

    svg.append("</svg>")
    return "\n".join(svg)

def _make_dendrogram(
    names: list[str],
    matrix: list[list[float]],
    title: str = "",
    width: int = 700,
    height: int = 350,
    color: str = "#2563eb",
) -> str:
    """Generate an SVG dendrogram from a similarity matrix using UPGMA."""
    n = len(names)
    if n < 2:
        return ""

    dist = [[1.0 - matrix[i][j] for j in range(n)] for i in range(n)]

    clusters = [{"members": [i], "height": 0.0, "left": None, "right": None} for i in range(n)]
    active = list(range(n))

    while len(active) > 1:
        best_d = float("inf")
        bi, bj = 0, 1
        for ii in range(len(active)):
            for jj in range(ii + 1, len(active)):
                ci, cj = active[ii], active[jj]
                d = 0
                count = 0
                for mi in clusters[ci]["members"]:
                    for mj in clusters[cj]["members"]:
                        d += dist[mi][mj]
                        count += 1
                avg_d = d / count if count else 0
                if avg_d < best_d:
                    best_d = avg_d
                    bi, bj = ii, jj

        ci, cj = active[bi], active[bj]
        new_cluster = {
            "members": clusters[ci]["members"] + clusters[cj]["members"],
            "height": best_d,
            "left": clusters[ci],
            "right": clusters[cj],
        }
        clusters.append(new_cluster)
        new_idx = len(clusters) - 1
        active.pop(bj)
        active.pop(bi)
        active.append(new_idx)

    root = clusters[active[0]]
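
    # Illustrative worked example (hypothetical distances, not from the tutorial):
    # with d(A,B)=0.1, d(A,C)=0.4 and d(B,C)=0.6, the loop above first merges A
    # and B at height 0.1. It then merges {A,B} with C at the average distance
    # (0.4 + 0.6) / 2 = 0.5, which becomes root["height"]. Heights here are the
    # raw average distances. Textbook UPGMA places each node at half that value,
    # which rescales the axis but does not change the tree's topology.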

    max_label_len = max((len(n) for n in names), default=0)
    ml, mr, mt, mb = max(50, max_label_len * 5 + 10), 30, 40, 80
    cw = width - ml - mr
    ch = height - mt - mb
    max_h = root["height"] or 1

    leaf_positions = {}
    leaf_counter = [0]

    def assign_leaves(node):
        if node["left"] is None and node["right"] is None:
            leaf_positions[node["members"][0]] = leaf_counter[0]
            leaf_counter[0] += 1
        else:
            if node["left"]:
                assign_leaves(node["left"])
            if node["right"]:
                assign_leaves(node["right"])

    assign_leaves(root)
    n_leaves = len(leaf_positions)
    leaf_spacing = cw / max(n_leaves - 1, 1)

    svg = [
        f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {height}" '
        f'style="width:100%;max-width:{width}px;height:auto;">',
        f'<rect width="{width}" height="{height}" fill="#fff" rx="6"/>',
    ]

    if title:
        svg.append(f'<text x="{width / 2}" y="22" text-anchor="middle" font-size="13" font-weight="600" fill="#1a1a2e">{title}</text>')

    def get_x(node):
        if node["left"] is None and node["right"] is None:
            return ml + leaf_positions[node["members"][0]] * leaf_spacing
        return (get_x(node["left"]) + get_x(node["right"])) / 2

    def get_y(h):
        return mt + ch - (h / max_h) * ch

    def draw_node(node):
        if node["left"] is None and node["right"] is None:
            return
        lx = get_x(node["left"])
        rx = get_x(node["right"])
        ly = get_y(node["left"]["height"])
        ry = get_y(node["right"]["height"])
        my = get_y(node["height"])

        svg.append(f'<line x1="{lx:.1f}" y1="{ly:.1f}" x2="{lx:.1f}" y2="{my:.1f}" stroke="{color}" stroke-width="2"/>')
        svg.append(f'<line x1="{rx:.1f}" y1="{ry:.1f}" x2="{rx:.1f}" y2="{my:.1f}" stroke="{color}" stroke-width="2"/>')
        svg.append(f'<line x1="{lx:.1f}" y1="{my:.1f}" x2="{rx:.1f}" y2="{my:.1f}" stroke="{color}" stroke-width="2"/>')

        if node["left"]:
            draw_node(node["left"])
        if node["right"]:
            draw_node(node["right"])

    draw_node(root)

    for idx, pos in leaf_positions.items():
        x = ml + pos * leaf_spacing
        svg.append(
            f'<text x="{x:.1f}" y="{mt + ch + 14}" text-anchor="start" font-size="10" fill="#374151" '
            f'transform="rotate(40, {x:.1f}, {mt + ch + 14})">{names[idx]}</text>'
        )

    for i in range(5):
        d = max_h * i / 4
        y = get_y(d)
        svg.append(f'<text x="{ml - 4}" y="{y + 3:.1f}" text-anchor="end" font-size="9" fill="#9ca3af">{d:.3f}</text>')

    svg.append("</svg>")
    return "\n".join(svg)

def _make_bar_chart(
    labels: list[str],
    series: dict[str, list[float]],
    title: str = "",
    colors: list[str] | None = None,
    width: int = 700,
    height: int = 300,
    value_format: str = ".1f",
) -> str:
    """Generate an SVG grouped bar chart."""
    if not labels:
        return ""

    default_colors = ["#2563eb", "#059669", "#f59e0b", "#dc2626", "#7c3aed"]
    colors = colors or default_colors

    ml, mr, mt, mb = 60, 20, 40, 80
    cw = width - ml - mr
    ch = height - mt - mb

    all_vals = [v for vals in series.values() for v in vals]
    y_min = min(all_vals) if all_vals else 0
    y_max = max(all_vals) if all_vals else 1
    if y_min >= 0:
        y_min_plot = 0
        y_max_plot = y_max * 1.15 or 1
    else:
        y_range = y_max - y_min or 1
        y_min_plot = y_min - y_range * 0.05
        y_max_plot = y_max + y_range * 0.15

    n_groups = len(labels)
    n_series = len(series)
    group_width = cw / n_groups
    bar_width = group_width * 0.7 / max(n_series, 1)
    gap = group_width * 0.15

    def sy(v):
        return mt + ch - ((v - y_min_plot) / (y_max_plot - y_min_plot)) * ch

    svg = [
        f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {height}" '
        f'style="width:100%;max-width:{width}px;height:auto;">',
        f'<rect width="{width}" height="{height}" fill="#fff" rx="6"/>',
    ]

    for i in range(6):
        y_tick = y_min_plot + (y_max_plot - y_min_plot) * i / 5
        py = sy(y_tick)
        svg.append(f'<line x1="{ml}" y1="{py:.1f}" x2="{ml + cw}" y2="{py:.1f}" stroke="#e9ecef" stroke-width="1"/>')
        svg.append(f'<text x="{ml - 8}" y="{py + 4:.1f}" text-anchor="end" font-size="11" fill="#6c757d">{y_tick:{value_format}}</text>')

    for gi, label in enumerate(labels):
        gx = ml + gi * group_width + gap
        for si, (name, vals) in enumerate(series.items()):
            color = colors[si % len(colors)]
            bx = gx + si * bar_width
            val = vals[gi]
            by = sy(val)
            bh = mt + ch - by
            svg.append(f'<rect x="{bx:.1f}" y="{by:.1f}" width="{bar_width - 1:.1f}" height="{max(0, bh):.1f}" fill="{color}" rx="2"/>')
            svg.append(f'<text x="{bx + bar_width / 2:.1f}" y="{by - 4:.1f}" text-anchor="middle" font-size="9" fill="#1a1a2e">{val:{value_format}}</text>')
        lx = gx + n_series * bar_width / 2
        svg.append(f'<text x="{lx:.1f}" y="{mt + ch + 14}" text-anchor="start" font-size="10" fill="#6c757d" transform="rotate(35, {lx:.1f}, {mt + ch + 14})">{label}</text>')

    if title:
        svg.append(f'<text x="{width / 2}" y="22" text-anchor="middle" font-size="14" font-weight="600" fill="#1a1a2e">{title}</text>')

    if n_series > 1:
        lx = ml + cw - len(series) * 110
        for si, name in enumerate(series):
            color = colors[si % len(colors)]
            svg.append(f'<rect x="{lx + si * 110}" y="{mt + ch + 55}" width="12" height="12" rx="2" fill="{color}"/>')
            svg.append(f'<text x="{lx + si * 110 + 16}" y="{mt + ch + 66}" font-size="11" fill="#1a1a2e">{name}</text>')

    svg.append("</svg>")
    return "\n".join(svg)

def _make_plddt_sparkline(values: list[float], width: int = 400, height: int = 50) -> str:
    """pLDDT sparkline with AlphaFold-style coloring."""
    if not values or len(values) < 2:
        return ""

    pad = 4
    cw = width - 2 * pad
    ch = height - 2 * pad
    seg_w = cw / len(values)

    svg = [
        f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {height}" '
        f'style="width:100%;max-width:{width}px;height:auto;">',
    ]

    for i, v in enumerate(values):
        x = pad + i * seg_w
        bar_h = (v / 100) * ch
        y = pad + ch - bar_h

        if v >= 90:
            color = "#0053d6"
        elif v >= 70:
            color = "#65cbf3"
        elif v >= 50:
            color = "#ffdb13"
        else:
            color = "#ff7d45"

        svg.append(f'<rect x="{x:.1f}" y="{y:.1f}" width="{max(seg_w, 1):.1f}" height="{bar_h:.1f}" fill="{color}"/>')

    ref_y = pad + ch - (70 / 100) * ch
    svg.append(f'<line x1="{pad}" y1="{ref_y:.1f}" x2="{pad + cw}" y2="{ref_y:.1f}" stroke="#adb5bd" stroke-width="0.5" stroke-dasharray="3,2"/>')

    svg.append("</svg>")
    return "\n".join(svg)

def _outputs_to_pdb(outputs, sequence: str) -> str:
    """Convert ESMFold outputs to PDB format string."""
    import numpy as np

    pos = outputs.positions[0]
    if pos.dim() == 4:
        pos = pos[-1]
    positions = pos.cpu().numpy()
    atom_names = ["N", "CA", "C", "O"]
    aa_3letter = {
        "A": "ALA", "R": "ARG", "N": "ASN", "D": "ASP", "C": "CYS",
        "Q": "GLN", "E": "GLU", "G": "GLY", "H": "HIS", "I": "ILE",
        "L": "LEU", "K": "LYS", "M": "MET", "F": "PHE", "P": "PRO",
        "S": "SER", "T": "THR", "W": "TRP", "Y": "TYR", "V": "VAL",
    }

    pdb_lines = []
    atom_idx = 1
    for res_idx, aa in enumerate(sequence):
        res_name = aa_3letter.get(aa, "UNK")
        for atom_i, atom_name in enumerate(atom_names):
            if atom_i >= positions.shape[1]:
                break
            x, y, z = positions[res_idx, atom_i]
            if any(math.isnan(c) for c in (x, y, z)):
                continue
            pdb_lines.append(
                f"ATOM  {atom_idx:5d}  {atom_name:<3s} {res_name} A{res_idx + 1:4d}    "
                f"{x:8.3f}{y:8.3f}{z:8.3f}  1.00  0.00           {atom_name[0]:>2s}"
            )
            atom_idx += 1
    pdb_lines.append("END")
    return "\n".join(pdb_lines)

# ------------------------------------------------------------------
# Task 1: Load gene set
# ------------------------------------------------------------------

@cpu_env.task()
async def load_genes(
    gene_set: str = "insulin",
    custom_json: str = "",
) -> flyte.io.Dir:
    """Load a set of homologous genes from different species."""
    if custom_json:
        data = json.loads(custom_json)
    elif gene_set in GENE_SETS:
        data = GENE_SETS[gene_set]
    else:
        available = ", ".join(GENE_SETS.keys())
        raise ValueError(f"Unknown gene set '{gene_set}'. Available: {available}")

    log.info(f"Loaded gene set: {data['gene_name']} - {len(data['sequences'])} species")

    out_dir = tempfile.mkdtemp(prefix="gene_compare_")
    with open(os.path.join(out_dir, "genes.json"), "w") as f:
        json.dump(data, f)

    return await flyte.io.Dir.from_local(out_dir)

# ------------------------------------------------------------------
# Task 2: Score sequences with Carbon
# ------------------------------------------------------------------

@gpu_env.task(report=True)
async def score_sequences(
    genes_dir: flyte.io.Dir,
    model_name: str = "HuggingFaceBio/Carbon-3B",
) -> str:
    """Score each species' gene with Carbon-3B genomic language model.

    Returns per-species log-likelihood scores and sequence metadata.
    """
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    log.info(f"Loading Carbon model: {model_name}")
    device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name, trust_remote_code=True,
        dtype=torch.bfloat16 if device == "cuda" else torch.float32,
    ).to(device)
    model.eval()

    genes_path = await genes_dir.download()
    with open(os.path.join(genes_path, "genes.json")) as f:
        data = json.load(f)

    species_names = list(data["sequences"].keys())
    n = len(species_names)

    scores = {}
    for i, species in enumerate(species_names):
        await flyte.report.replace.aio(_wrap_report(
            f"<h2>Carbon Scoring</h2>"
            f"<p>Scoring {species} ({i + 1}/{n})...</p>"
        ), do_flush=True)

        dna = data["sequences"][species]["dna"]
        prompt = f"<dna>{dna}"
        inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(device)

        with torch.no_grad():
            output = model(**inputs, labels=inputs["input_ids"])
            loss = output.loss.item()
            # loss averages over the seq_len - 1 shifted next-token predictions
            ll = -loss * (inputs["input_ids"].shape[1] - 1)

        protein = _translate(dna)
        scores[species] = {
            "log_likelihood": round(ll, 4),
            "loss": round(loss, 4),
            "gc_content": round(_gc_content(dna), 4),
            "length": len(dna),
            "protein": protein,
            "protein_length": len(protein),
            "common_name": data["sequences"][species]["common_name"],
        }
        log.info(f"  {species} ({data['sequences'][species]['common_name']}): LL={ll:.2f}, GC={_gc_content(dna):.1%}")

    # Report
    html_parts = [
        f"<h2>{data['gene_name']} - Carbon Scoring</h2>",
        f'<div class="note">{data["description"]}</div>',
        '<div class="stat-grid">',
        f'<div class="stat"><div class="value">{n}</div><div class="label">Species</div></div>',
        f'<div class="stat"><div class="value">{data["gene_name"]}</div><div class="label">Gene</div></div>',
        f'<div class="stat"><div class="value">{model_name.split("/")[-1]}</div><div class="label">Model</div></div>',
        "</div>",
    ]

    html_parts.append(
        "<table><tr><th>Species</th><th>Scientific Name</th><th>DNA Length</th>"
        "<th>GC%</th><th>Protein Length</th><th>Carbon LL</th></tr>"
    )
    for species in species_names:
        s = scores[species]
        html_parts.append(
            f'<tr><td><b>{species}</b></td><td><i>{s["common_name"]}</i></td>'
            f'<td>{s["length"]}bp</td><td>{s["gc_content"]:.1%}</td>'
            f'<td>{s["protein_length"]}aa</td>'
            f'<td>{s["log_likelihood"]:.2f}</td></tr>'
        )
    html_parts.append("</table>")

    html_parts.append('<div class="chart-container">')
    html_parts.append(_make_bar_chart(
        species_names,
        {"Log-Likelihood": [scores[s]["log_likelihood"] for s in species_names]},
        title="Carbon Log-Likelihood per Species",
        value_format=".1f",
    ))
    html_parts.append("</div>")

    await flyte.report.replace.aio(_wrap_report("\n".join(html_parts)), do_flush=True)

    result = {
        "gene_name": data["gene_name"],
        "description": data["description"],
        "species": species_names,
        "scores": scores,
    }
    return json.dumps(result)

# ------------------------------------------------------------------
# Task 3: Align sequences and compute similarity
# ------------------------------------------------------------------

@cpu_env.task(report=True)
async def align_and_compare(
    scores_json: str,
    genes_dir: flyte.io.Dir,
) -> str:
    """Align sequences with Needleman-Wunsch and compute pairwise identity.

    Translates DNA to protein, builds DNA and protein identity matrices,
    and generates phylogenetic trees from sequence divergence.
    """
    scores_data = json.loads(scores_json)
    species_names = scores_data["species"]
    scores = scores_data["scores"]
    gene_name = scores_data["gene_name"]
    n = len(species_names)

    genes_path = await genes_dir.download()
    with open(os.path.join(genes_path, "genes.json")) as f:
        data = json.load(f)

    await flyte.report.replace.aio(_wrap_report(
        f"<h2>{gene_name} - Sequence Alignment</h2>"
        f"<p>Aligning {n} species with Needleman-Wunsch...</p>"
    ), do_flush=True)

    # Pairwise DNA identity matrix
    identity_matrix = []
    for sp1 in species_names:
        row = []
        for sp2 in species_names:
            dna1 = data["sequences"][sp1]["dna"]
            dna2 = data["sequences"][sp2]["dna"]
            identity = _sequence_identity(dna1, dna2)
            row.append(round(identity, 4))
        identity_matrix.append(row)

    # Pairwise protein identity matrix
    protein_matrix = []
    for sp1 in species_names:
        row = []
        for sp2 in species_names:
            identity = _sequence_identity(scores[sp1]["protein"], scores[sp2]["protein"])
            row.append(round(identity, 4))
        protein_matrix.append(row)

    # Average pairwise identities (exclude diagonal)
    dna_pairs = [identity_matrix[i][j] for i in range(n) for j in range(i + 1, n)]
    prot_pairs = [protein_matrix[i][j] for i in range(n) for j in range(i + 1, n)]
    avg_dna = sum(dna_pairs) / len(dna_pairs) if dna_pairs else 0
    avg_prot = sum(prot_pairs) / len(prot_pairs) if prot_pairs else 0

    # Most/least similar pair
    best_pair = max(range(len(dna_pairs)), key=lambda k: dna_pairs[k])
    worst_pair = min(range(len(dna_pairs)), key=lambda k: dna_pairs[k])
    pair_indices = [(i, j) for i in range(n) for j in range(i + 1, n)]
    best_sp = f"{species_names[pair_indices[best_pair][0]]}-{species_names[pair_indices[best_pair][1]]}"
    worst_sp = f"{species_names[pair_indices[worst_pair][0]]}-{species_names[pair_indices[worst_pair][1]]}"

    # Report
    html_parts = [
        f"<h2>{gene_name} - Sequence Alignment</h2>",
        f'<div class="note">Pairwise alignment using Needleman-Wunsch (match=2, mismatch=-1, gap=-2). '
        f"Identity is computed as matches / aligned length from the optimal global alignment.</div>",
        '<div class="stat-grid">',
        f'<div class="stat"><div class="value">{n}</div><div class="label">Species Aligned</div></div>',
        f'<div class="stat"><div class="value">{n * (n - 1) // 2}</div><div class="label">Pairwise Alignments</div></div>',
        f'<div class="stat"><div class="value">{avg_dna:.0%}</div><div class="label">Avg DNA Identity</div></div>',
        f'<div class="stat"><div class="value">{avg_prot:.0%}</div><div class="label">Avg Protein Identity</div></div>',
        f'<div class="stat"><div class="value">{best_sp}</div><div class="label">Most Similar</div></div>',
        f'<div class="stat"><div class="value">{worst_sp}</div><div class="label">Most Divergent</div></div>',
        "</div>",
    ]

    # DNA identity heatmap
    html_parts.append('<div class="chart-container">')
    html_parts.append(_make_heatmap(
        identity_matrix, species_names, species_names,
        title="Pairwise DNA Sequence Identity (%)",
        value_format=".0%",
    ))
    html_parts.append("</div>")

    # Protein identity heatmap
    html_parts.append('<div class="chart-container">')
    html_parts.append(_make_heatmap(
        protein_matrix, species_names, species_names,
        title="Pairwise Protein Sequence Identity (%)",
        value_format=".0%",
        color_scale="green",
    ))
    html_parts.append("</div>")

    # DNA phylogenetic tree
    html_parts.append('<div class="chart-container">')
    html_parts.append(_make_dendrogram(
        species_names, identity_matrix,
        title=f"{gene_name} - Phylogenetic Tree (DNA Identity)",
    ))
    html_parts.append("</div>")

    # Protein phylogenetic tree
    html_parts.append('<div class="chart-container">')
    html_parts.append(_make_dendrogram(
        species_names, protein_matrix,
        title=f"{gene_name} - Phylogenetic Tree (Protein Identity)",
        color="#059669",
    ))
    html_parts.append("</div>")

    # DNA vs Protein conservation comparison
    html_parts.append('<div class="chart-container">')
    html_parts.append(_make_bar_chart(
        species_names,
        {
            f"DNA vs {species_names[0]}": [identity_matrix[0][j] for j in range(n)],
            f"Protein vs {species_names[0]}": [protein_matrix[0][j] for j in range(n)],
        },
        title=f"Conservation vs {species_names[0]} (DNA and Protein)",
        value_format=".0%",
    ))
    html_parts.append("</div>")

    await flyte.report.replace.aio(_wrap_report("\n".join(html_parts)), do_flush=True)

    result = {
        "gene_name": gene_name,
        "description": scores_data["description"],
        "species": species_names,
        "scores": scores,
        "dna_identity_matrix": identity_matrix,
        "protein_identity_matrix": protein_matrix,
    }
    return json.dumps(result)

# ------------------------------------------------------------------
# Task 4: Fold proteins with ESMFold
# ------------------------------------------------------------------

@gpu_env.task(report=True)
async def fold_proteins(
    comparison_json: str,
    max_length: int = 400,
) -> str:
    """Fold each species' translated protein with ESMFold for 3D comparison.

    Returns PDB strings and pLDDT confidence scores for each species.
    """
    import torch
    import numpy as np
    from transformers import AutoTokenizer, EsmForProteinFolding

    comparison = json.loads(comparison_json)
    species_names = comparison["species"]
    scores = comparison["scores"]

    log.info("Loading ESMFold model...")
    device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
    model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", low_cpu_mem_usage=True)
    model = model.to(device)
    model.eval()

    structure_data = {}
    n = len(species_names)

    for idx, species in enumerate(species_names):
        protein = scores[species]["protein"]

        if len(protein) > max_length:
            log.info(f"Skipping {species} ({len(protein)} aa > {max_length} max)")
            continue

        log.info(f"ESMFold [{idx + 1}/{n}]: {species} ({len(protein)} aa)")
        await flyte.report.replace.aio(_wrap_report(
            f"<h2>ESMFold - 3D Structure Prediction</h2>"
            f"<p>Folding {species} ({idx + 1}/{n}): {len(protein)} residues...</p>"
        ), do_flush=True)

        inputs = tokenizer(protein, return_tensors="pt", add_special_tokens=False).to(device)

        with torch.no_grad():
            outputs = model(**inputs)

        pdb_str = _outputs_to_pdb(outputs, protein)

        plddt_raw = outputs.plddt[0].cpu().numpy()
        if plddt_raw.ndim == 2:
            plddt_raw = plddt_raw[-1]
        plddt = plddt_raw.flatten()[:len(protein)]
        if plddt.max() <= 1.0:
            plddt = plddt * 100
        plddt_mean = float(np.mean(plddt))

        structure_data[species] = {
            "pdb_str": pdb_str,
            "plddt_mean": round(plddt_mean, 1),
            "plddt_per_residue": [round(float(v), 1) for v in plddt[:len(protein)]],
            "protein_length": len(protein),
        }
        log.info(f"  → mean pLDDT: {plddt_mean:.1f}")

    # Report with 3D viewers
    n_folded = len(structure_data)
    avg_plddt = sum(d["plddt_mean"] for d in structure_data.values()) / n_folded if n_folded else 0

    threeDmol_script = '<script src="https://3dmol.csb.pitt.edu/build/3Dmol-min.js"></script>'

    stats_html = f"""
    <h2>ESMFold - Cross-Species Structure Comparison</h2>
    <div class="note">
      <b>ESMFold</b> predicts 3D structure directly from amino acid sequence.
      Comparing structures across species reveals which parts of the protein are
      structurally conserved (functional core) vs divergent (surface loops, species-specific adaptations).
    </div>
    <div class="stat-grid">
      <div class="stat"><div class="value">{n_folded}</div><div class="label">Structures</div></div>
      <div class="stat"><div class="value">{avg_plddt:.1f}</div><div class="label">Avg pLDDT</div></div>
      <div class="stat"><div class="value">{comparison['gene_name']}</div><div class="label">Gene</div></div>
    </div>
    """

    viewers_html = '<div class="structure-grid">'
    for species, sdata in structure_data.items():
        plddt_val = sdata["plddt_mean"]
        common = scores[species]["common_name"]

        if plddt_val >= 90:
            badge = '<span class="badge badge-success">Very High</span>'
        elif plddt_val >= 70:
            badge = '<span class="badge badge-info">Confident</span>'
        elif plddt_val >= 50:
            badge = '<span class="badge badge-warning">Low</span>'
        else:
            badge = '<span class="badge badge-danger">Disordered</span>'

        plddt_sparkline = _make_plddt_sparkline(sdata["plddt_per_residue"], width=300)
        pdb_escaped = sdata["pdb_str"].replace("\\", "\\\\").replace("`", "\\`").replace("$", "\\$")
        viewer_id = f"viewer_{hash(species) & 0xFFFFFF:06x}"

        viewers_html += f"""
        <div class="card" style="margin:0;">
          <h3 style="margin-top:0;">{species}
            <span style="font-size:0.7em;color:#6c757d;">({sdata['protein_length']} aa)</span>
            {badge}
          </h3>
          <p style="font-size:0.85em;color:#6c757d;margin:2px 0 8px;"><i>{common}</i></p>
          <div id="{viewer_id}" style="width:100%;max-width:320px;height:280px;border:1px solid #dbeafe;border-radius:8px;position:relative;"></div>
          <div style="margin-top:8px;">
            <b>Mean pLDDT:</b> {plddt_val:.1f} / 100
            <div style="margin-top:4px;">{plddt_sparkline}</div>
            <div style="font-size:0.75em;color:#9ca3af;margin-top:2px;">
              <span style="color:#0053d6;">&block; &gt;90</span>
              <span style="color:#65cbf3;">&block; 70-90</span>
              <span style="color:#ffdb13;">&block; 50-70</span>
              <span style="color:#ff7d45;">&block; &lt;50</span>
            </div>
          </div>
        </div>
        <script>
        (function() {{
          var pdb = `{pdb_escaped}`;
          function initViewer() {{
            if (typeof $3Dmol === 'undefined') {{ setTimeout(initViewer, 200); return; }}
            var viewer = $3Dmol.createViewer(document.getElementById("{viewer_id}"), {{backgroundColor: "white"}});
            viewer.addModel(pdb, "pdb");
            viewer.setStyle({{}}, {{cartoon: {{color: "spectrum"}}}});
            viewer.zoomTo();
            viewer.render();
            viewer.spin("y", 1);
          }}
          initViewer();
        }})();
        </script>
        """

    viewers_html += "</div>"

    # pLDDT comparison bar chart
    plddt_chart = _make_bar_chart(
        list(structure_data.keys()),
        {"Mean pLDDT": [d["plddt_mean"] for d in structure_data.values()]},
        title="Structure Confidence Comparison (pLDDT)",
        value_format=".1f",
        colors=["#0053d6"],
    )

    report_html = f"""
    {threeDmol_script}
    {stats_html}
    {viewers_html}
    <div class="chart-container">{plddt_chart}</div>
    """

    await flyte.report.replace.aio(_wrap_report(report_html), do_flush=True)

    return json.dumps(structure_data)

# ------------------------------------------------------------------
# Task 5: Generate summary
# ------------------------------------------------------------------

@cpu_env.task(report=True)
async def generate_summary(
    comparison_json: str,
    structures_json: str,
) -> str:
    """Generate comprehensive cross-species summary."""
    comparison = json.loads(comparison_json)
    structures = json.loads(structures_json)

    species = comparison["species"]
    scores = comparison["scores"]
    gene_name = comparison["gene_name"]
    dna_matrix = comparison["dna_identity_matrix"]
    protein_matrix = comparison["protein_identity_matrix"]

    html_parts = [
        f"<h2>{gene_name} - Cross-Species Evolution Summary</h2>",
        f'<div class="note">{comparison["description"]}</div>',
    ]

    # Key metrics
    # Average pairwise identity (exclude diagonal)
    n = len(species)
    dna_pairs = [dna_matrix[i][j] for i in range(n) for j in range(i + 1, n)]
    protein_pairs = [protein_matrix[i][j] for i in range(n) for j in range(i + 1, n)]
    avg_dna_id = sum(dna_pairs) / len(dna_pairs) if dna_pairs else 0
    avg_protein_id = sum(protein_pairs) / len(protein_pairs) if protein_pairs else 0
    avg_plddt = sum(d["plddt_mean"] for d in structures.values()) / len(structures) if structures else 0

    html_parts.append('<div class="stat-grid">')
    html_parts.append(f'<div class="stat"><div class="value">{n}</div><div class="label">Species</div></div>')
    html_parts.append(f'<div class="stat"><div class="value">{avg_dna_id:.0%}</div><div class="label">Avg DNA Identity</div></div>')
    html_parts.append(f'<div class="stat"><div class="value">{avg_protein_id:.0%}</div><div class="label">Avg Protein Identity</div></div>')
    html_parts.append(f'<div class="stat"><div class="value">{avg_plddt:.1f}</div><div class="label">Avg pLDDT</div></div>')
    html_parts.append(f'<div class="stat"><div class="value">{len(structures)}</div><div class="label">Structures Folded</div></div>')
    html_parts.append("</div>")

    # Full comparison table
    html_parts.append("<h3>Per-Species Detail</h3>")
    html_parts.append(
        "<table><tr><th>Species</th><th>Scientific Name</th><th>DNA (bp)</th>"
        "<th>Protein (aa)</th><th>GC%</th><th>Carbon LL</th><th>pLDDT</th></tr>"
    )
    for sp in species:
        s = scores[sp]
        plddt = structures.get(sp, {}).get("plddt_mean", "N/A")
        plddt_str = f"{plddt:.1f}" if isinstance(plddt, float) else plddt
        html_parts.append(
            f'<tr><td><b>{sp}</b></td><td><i>{s["common_name"]}</i></td>'
            f'<td>{s["length"]}</td><td>{s["protein_length"]}</td>'
            f'<td>{s["gc_content"]:.1%}</td><td>{s["log_likelihood"]:.2f}</td>'
            f'<td>{plddt_str}</td></tr>'
        )
    html_parts.append("</table>")

    # GC content comparison
    html_parts.append('<div class="chart-container">')
    html_parts.append(_make_bar_chart(
        species,
        {"GC Content": [scores[s]["gc_content"] for s in species]},
        title="GC Content Across Species",
        value_format=".2f",
    ))
    html_parts.append("</div>")

    # DNA phylogenetic tree
    html_parts.append("<h3>Phylogenetic Relationships</h3>")
    html_parts.append(
        '<div class="note">'
        "Trees built from pairwise sequence identity using UPGMA clustering. "
        "Species that diverged more recently cluster together. DNA and protein trees "
        "may differ when synonymous mutations dominate."
        "</div>"
    )

    html_parts.append('<div class="chart-container">')
    html_parts.append(_make_dendrogram(
        species, dna_matrix,
        title=f"{gene_name} - DNA Phylogenetic Tree",
    ))
    html_parts.append("</div>")

    html_parts.append('<div class="chart-container">')
    html_parts.append(_make_dendrogram(
        species, protein_matrix,
        title=f"{gene_name} - Protein Phylogenetic Tree",
        color="#059669",
    ))
    html_parts.append("</div>")

    await flyte.report.replace.aio(_wrap_report("\n".join(html_parts)), do_flush=True)

    summary = {
        "gene_name": gene_name,
        "n_species": n,
        "avg_dna_identity": round(avg_dna_id, 4),
        "avg_protein_identity": round(avg_protein_id, 4),
        "avg_plddt": round(avg_plddt, 1),
        "n_structures": len(structures),
    }
    return json.dumps(summary)

# ------------------------------------------------------------------
# Pipeline orchestrator
# ------------------------------------------------------------------

# {{docs-fragment pipeline}}
@cpu_env.task(report=True)
async def pipeline(
    gene_set: str = "insulin",
    model_name: str = "HuggingFaceBio/Carbon-3B",
    custom_json: str = "",
) -> tuple[str, str]:
    """
    End-to-end cross-species gene comparison pipeline.

    Returns (comparison JSON, structures JSON).

    1. Load homologous gene sequences across species
    2. Score with Carbon genomic language model
    3. Align sequences and compute pairwise similarity
    4. Fold translated proteins with ESMFold
    5. Generate comprehensive summary with phylogenetic trees
    """
    log.info(f"Starting cross-species gene comparison pipeline (gene_set={gene_set})...")

    def _pipeline_progress(step: int, label: str) -> str:
        steps = [
            "Load Genes",
            "Carbon Scoring",
            "Sequence Alignment",
            "ESMFold Structures",
            "Generate Summary",
        ]
        dots = ""
        for i, s in enumerate(steps):
            if i + 1 < step:
                icon = '<span style="color:#2563eb;">&#10003;</span>'
            elif i + 1 == step:
                icon = '<span style="color:#2563eb;">&#9679;</span>'
            else:
                icon = '<span style="color:#adb5bd;">&#9675;</span>'
            dots += f"<span style='margin:0 8px;'>{icon} {s}</span>"
        return f"""
        <h2>Cross-Species Gene Comparison</h2>
        <div class="card" style="text-align:center;">{dots}</div>
        <p>{label}</p>
        """

    # Stage 1
    await flyte.report.replace.aio(
        _wrap_report(_pipeline_progress(1, "Loading homologous gene sequences...")),
        do_flush=True,
    )
    genes_dir = await load_genes(gene_set=gene_set, custom_json=custom_json)

    # Stage 2
    await flyte.report.replace.aio(
        _wrap_report(_pipeline_progress(2, "Scoring sequences with Carbon...")),
        do_flush=True,
    )
    scores_json = await score_sequences(genes_dir=genes_dir, model_name=model_name)

    # Stage 3
    await flyte.report.replace.aio(
        _wrap_report(_pipeline_progress(3, "Aligning sequences with Needleman-Wunsch...")),
        do_flush=True,
    )
    comparison_json = await align_and_compare(scores_json=scores_json, genes_dir=genes_dir)

    # Stage 4
    await flyte.report.replace.aio(
        _wrap_report(_pipeline_progress(4, "Folding proteins with ESMFold...")),
        do_flush=True,
    )
    structures_json = await fold_proteins(comparison_json=comparison_json)

    # Stage 5
    await flyte.report.replace.aio(
        _wrap_report(_pipeline_progress(5, "Generating summary report...")),
        do_flush=True,
    )
    summary_json = await generate_summary(
        comparison_json=comparison_json,
        structures_json=structures_json,
    )

    # Final report
    summary = json.loads(summary_json)
    comparison = json.loads(comparison_json)

    final_html = f"""
    <h2>Pipeline Complete</h2>
    <div class="stat-grid">
      <div class="stat"><div class="value">{summary['gene_name']}</div><div class="label">Gene</div></div>
      <div class="stat"><div class="value">{summary['n_species']}</div><div class="label">Species</div></div>
      <div class="stat"><div class="value">{summary['avg_dna_identity']:.0%}</div><div class="label">Avg DNA Identity</div></div>
      <div class="stat"><div class="value">{summary['avg_protein_identity']:.0%}</div><div class="label">Avg Protein Identity</div></div>
      <div class="stat"><div class="value">{summary['avg_plddt']:.1f}</div><div class="label">Avg pLDDT</div></div>
      <div class="stat"><div class="value">{summary['n_structures']}</div><div class="label">3D Structures</div></div>
    </div>
    <div class="card">
      <b>Gene:</b> {summary['gene_name']} |
      <b>Species:</b> {', '.join(comparison['species'])} |
      <b>Model:</b> {model_name}
    </div>
    <div class="note">
      All 5 pipeline stages completed. View individual task reports for DNA/protein
      identity heatmaps, phylogenetic trees, interactive 3D protein structures with
      pLDDT confidence, Carbon log-likelihood scores, and evolutionary analysis.
    </div>
    """

    await flyte.report.replace.aio(_wrap_report(final_html), do_flush=True)
    log.info("Pipeline complete.")
    return comparison_json, structures_json

# {{/docs-fragment pipeline}}

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(pipeline)
    print(run.url)
    run.wait()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/genomic_gene_comparison/genomic_gene_comparison.py*
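The alignment report describes identity as matches / aligned length from an optimal Needleman-Wunsch global alignment scored with match=2, mismatch=-1, gap=-2. The `_sequence_identity` helper is defined outside this excerpt, so what follows is an independent sketch of that calculation with linear gap penalties, not the tutorial's implementation. The name `nw_identity` is hypothetical, and the traceback's tie-breaking order (diagonal, then up, then left) is an assumption that can shift identity slightly when several alignments share the best score.

```python
def nw_identity(a: str, b: str, match: int = 2, mismatch: int = -1, gap: int = -2) -> float:
    """Identity = matches / aligned length of a Needleman-Wunsch global alignment (sketch)."""
    n, m = len(a), len(b)
    # score[i][j] is the best score for aligning a[:i] with b[:j]
    score = [[0] * (m + 1) for _ in range(n + 1)]
    for i in range(1, n + 1):
        score[i][0] = i * gap
    for j in range(1, m + 1):
        score[0][j] = j * gap
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            diag = score[i - 1][j - 1] + (match if a[i - 1] == b[j - 1] else mismatch)
            score[i][j] = max(diag, score[i - 1][j] + gap, score[i][j - 1] + gap)

    # Trace back from the bottom-right cell, counting matches and alignment columns
    i, j = n, m
    matches = length = 0
    while i > 0 or j > 0:
        if i > 0 and j > 0 and score[i][j] == score[i - 1][j - 1] + (
            match if a[i - 1] == b[j - 1] else mismatch
        ):
            matches += a[i - 1] == b[j - 1]
            i, j = i - 1, j - 1
        elif i > 0 and score[i][j] == score[i - 1][j] + gap:
            i -= 1  # gap in b
        else:
            j -= 1  # gap in a
        length += 1
    return matches / length if length else 0.0


print(nw_identity("ACGT", "AGT"))  # aligns as ACGT / A-GT
```

The dynamic-programming table costs O(len(a) x len(b)) time and memory, which is comfortable for the gene-length sequences used here but would need a banded or library aligner for much longer inputs.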

## Run the workflow

From the root of the `unionai-examples` repository, change into the [example directory](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/genomic_gene_comparison) and run the script:

```
cd v2/tutorials/genomic_gene_comparison
uv run --script genomic_gene_comparison.py
```

Or submit a specific gene set with the Flyte CLI:

```
flyte run genomic_gene_comparison.py pipeline --gene_set hemoglobin
```

Carbon scoring and ESMFold folding both run in `gpu_env`, so this example needs a GPU. Open the run URL and check each task's report tab for heatmaps, dendrograms, and interactive 3D viewers.
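The summary report notes that the trees come from UPGMA clustering on pairwise sequence identity. `_make_dendrogram` is outside this excerpt, so here is an independent sketch of UPGMA that assumes distance = 1 - identity and returns a Newick string instead of SVG. The name `upgma_newick` is hypothetical.

```python
def upgma_newick(names: list[str], identity: list[list[float]]) -> str:
    """UPGMA on distance = 1 - identity; returns a Newick tree string (sketch)."""
    if not names:
        raise ValueError("need at least one species")
    n = len(names)
    # Each cluster id maps to (Newick fragment, number of leaves, node height)
    clusters = {i: (names[i], 1, 0.0) for i in range(n)}
    # Distances keyed by (smaller id, larger id)
    dist = {(i, j): 1.0 - identity[i][j] for i in range(n) for j in range(i + 1, n)}
    next_id = n

    while len(clusters) > 1:
        # Merge the closest pair; UPGMA places the new node at half their distance
        (a, b), d = min(dist.items(), key=lambda kv: kv[1])
        label_a, size_a, h_a = clusters.pop(a)
        label_b, size_b, h_b = clusters.pop(b)
        height = d / 2
        label = f"({label_a}:{height - h_a:.3f},{label_b}:{height - h_b:.3f})"

        # Size-weighted average distance from the merged cluster to every other cluster
        new_dist = {}
        for k in clusters:
            d_ak = dist[tuple(sorted((a, k)))]
            d_bk = dist[tuple(sorted((b, k)))]
            new_dist[k] = (size_a * d_ak + size_b * d_bk) / (size_a + size_b)

        dist = {pair: v for pair, v in dist.items() if a not in pair and b not in pair}
        for k, v in new_dist.items():
            dist[(k, next_id)] = v  # next_id exceeds every existing id, so the key stays sorted
        clusters[next_id] = (label, size_a + size_b, height)
        next_id += 1

    return next(iter(clusters.values()))[0] + ";"


tree = upgma_newick(
    ["Human", "Chimp", "Mouse"],
    [[1.0, 0.98, 0.80], [0.98, 1.0, 0.80], [0.80, 0.80, 1.0]],
)
print(tree)
```

Because UPGMA assumes a constant rate of change across lineages, the resulting branch lengths are best read as relative divergence rather than calibrated evolutionary time.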

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/biotech-healthcare/genomic-variant-effect ===

# Genomic variant effect prediction

> **📝 Note**
>
> The code is available in the [example directory](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/genomic_variant_effect).

This tutorial demonstrates zero-shot variant effect prediction (VEP) with Hugging Face [Carbon](https://huggingface.co/HuggingFaceBio/Carbon-3B). The pipeline loads clinically relevant variants across genes such as BRCA2, TP53, CFTR, KRAS, and HBB, scores each mutation with a log-likelihood ratio, and produces HTML reports with DNA tracks, lollipop plots, confusion matrices, and ranked pathogenicity tables.

Flyte provides:

- **GPU-backed inference** for Carbon scoring with live progress reports.
- **CPU analysis tasks** for visualization and accuracy metrics without holding a GPU.
- **End-to-end orchestration** from variant loading through summary reporting.
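The log-likelihood ratio mentioned above compares how plausible the model finds a sequence with and without the mutation. One common zero-shot formulation, assumed here rather than taken from this tutorial's scoring code, is LLR = log P(alt sequence) - log P(ref sequence), where more negative values suggest a more disruptive variant. The sketch below shows that arithmetic with a hypothetical per-base `toy_score` standing in for Carbon, plus the conversion from a causal language model's mean loss to a summed log-likelihood. Variant positions are 0-indexed, matching the variant table's convention.

```python
import math
from typing import Callable


def total_log_likelihood(mean_loss: float, n_tokens: int) -> float:
    # Hugging Face causal LMs average the loss over the n_tokens - 1 shifted
    # next-token predictions, so the summed log-likelihood is -loss * (n_tokens - 1).
    return -mean_loss * (n_tokens - 1)


def apply_variant(seq: str, pos: int, ref: str, alt: str) -> str:
    """Substitute one base at a 0-indexed position after checking the reference allele."""
    if seq[pos] != ref:
        raise ValueError(f"reference mismatch at {pos}: expected {ref}, found {seq[pos]}")
    return seq[:pos] + alt + seq[pos + 1:]


def log_likelihood_ratio(
    score_fn: Callable[[str], float], seq: str, pos: int, ref: str, alt: str
) -> float:
    """LLR = log P(alt) - log P(ref); negative means the model prefers the reference."""
    return score_fn(apply_variant(seq, pos, ref, alt)) - score_fn(seq)


# Hypothetical stand-in for Carbon: independent per-base probabilities.
TOY_BASE_PROBS = {"A": 0.3, "C": 0.2, "G": 0.2, "T": 0.3}


def toy_score(seq: str) -> float:
    return sum(math.log(TOY_BASE_PROBS[base]) for base in seq)


print(log_likelihood_ratio(toy_score, "ATGGCC", 3, "G", "A"))
```

With a real model, `score_fn` would run the reference and mutant windows through the language model and return their summed log-likelihoods; the reference-allele check catches table entries whose `pos` and `ref` disagree with the stored sequence.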

## Define the task environments

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.4.0",
#    "torch>=2.9.0",
#    "transformers>=4.49.0",
#    "accelerate>=0.34.0",
#    "numpy",
# ]
# main = "pipeline"
# params = ""
# ///
import json
import logging
import math
import os
import tempfile

import flyte
import flyte.io
import flyte.report

# {{docs-fragment env}}
main_img = flyte.Image.from_uv_script(__file__, name="genomic-variant-effect", pre=True)

gpu_env = flyte.TaskEnvironment(
    name="genomic-variant-effect-gpu",
    image=main_img,
    resources=flyte.Resources(cpu=4, memory="24Gi", gpu=1),
)

cpu_env = flyte.TaskEnvironment(
    name="genomic-variant-effect-cpu",
    image=main_img,
    resources=flyte.Resources(cpu=2, memory="6Gi"),
    depends_on=[gpu_env],
)
# {{/docs-fragment env}}

logging.basicConfig(level=logging.WARNING, format="%(message)s", force=True)
log = logging.getLogger(__name__)
log.setLevel(logging.INFO)

# ------------------------------------------------------------------
# Default gene variants — clinically relevant mutations
# ------------------------------------------------------------------
# Each entry: gene name -> { "sequence": reference DNA, "variants": [{ "pos": 0-indexed, "ref": base, "alt": base, "name": "...", "known_effect": "..." }] }
# Sequences are short windows (~120-200bp) around the variant site for tractable inference.

DEFAULT_GENE_VARIANTS = {
    "BRCA2 (Breast Cancer)": {
        "description": "Tumor suppressor critical for DNA repair via homologous recombination. Mutations dramatically increase breast and ovarian cancer risk.",
        "sequence": "ATGGCCTCGAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAG",
        "variants": [
            {"pos": 12, "ref": "A", "alt": "T", "name": "c.37A>T", "known_effect": "pathogenic", "clinical": "Nonsense mutation — truncates protein early"},
            {"pos": 18, "ref": "G", "alt": "A", "name": "c.55G>A", "known_effect": "benign", "clinical": "Synonymous — no amino acid change"},
            {"pos": 30, "ref": "C", "alt": "T", "name": "c.91C>T", "known_effect": "pathogenic", "clinical": "Missense in DNA-binding domain"},
            {"pos": 45, "ref": "G", "alt": "C", "name": "c.136G>C", "known_effect": "uncertain", "clinical": "Variant of uncertain significance (VUS)"},
        ],
    },
    "TP53 (Tumor Suppressor)": {
        "description": "Guardian of the genome. Activates DNA repair, cell cycle arrest, and apoptosis. Mutated in >50% of human cancers.",
        "sequence": "ATGGAGGAGCCGCAGTCAGATCCTAGCGTGAGTTTGCACCCTTCAGAGACAGAAACCACTGGATTGGAGACTACTTCCTGAAACAACGTTCTGTCCCCCTTGCCGTCCCAAGCAATGGATGAT",
        "variants": [
            {"pos": 15, "ref": "C", "alt": "T", "name": "R175H", "known_effect": "pathogenic", "clinical": "Hotspot — gain-of-function, dominant negative. Most common TP53 mutation in cancer"},
            {"pos": 36, "ref": "T", "alt": "C", "name": "P72R", "known_effect": "benign", "clinical": "Common polymorphism — subtle effect on apoptosis efficiency"},
            {"pos": 54, "ref": "C", "alt": "A", "name": "G245S", "known_effect": "pathogenic", "clinical": "Structural mutant: distorts the L3 loop of the DNA-binding domain"},
            {"pos": 72, "ref": "T", "alt": "G", "name": "R248W", "known_effect": "pathogenic", "clinical": "Contact mutant: R248 normally binds DNA in the minor groove"},
            {"pos": 90, "ref": "C", "alt": "T", "name": "R273H", "known_effect": "pathogenic", "clinical": "Contact mutant: R273 normally contacts the DNA phosphate backbone"},
        ],
    },
    "CFTR (Cystic Fibrosis)": {
        "description": "Chloride channel protein. Mutations cause cystic fibrosis — the most common lethal genetic disease in people of European descent.",
        "sequence": "ATGCAGAGGTCGCCTCTGGAAAAGGCCAGCGTTGTCTCCAAACTTTTTTTCAGCTGGACCAGACCAATTTTGAGGAAAGGATACAGACAGCGCCTGGAATTGTCAGACATATACCAAATCCCTTC",
        "variants": [
            {"pos": 9, "ref": "G", "alt": "A", "name": "G85E", "known_effect": "pathogenic", "clinical": "Disrupts chloride channel processing"},
            {"pos": 24, "ref": "C", "alt": "T", "name": "R117H", "known_effect": "pathogenic", "clinical": "Reduces channel conductance — milder CF phenotype"},
            {"pos": 48, "ref": "T", "alt": "C", "name": "I148T", "known_effect": "benign", "clinical": "Previously misclassified — now known benign polymorphism"},
            {"pos": 66, "ref": "A", "alt": "G", "name": "R334W", "known_effect": "pathogenic", "clinical": "Conductance mutation: reduces chloride flow through the open channel"},
        ],
    },
    "KRAS (Oncogene)": {
        "description": "GTPase signal switch. KRAS mutations are the most common oncogenic driver — found in ~25% of all human cancers, especially pancreatic, colorectal, and lung.",
        "sequence": "ATGACTGAATATAAACTTGTGGTAGTTGGAGCTGGTGGCGTAGGCAAGAGTGCCTTGACGATACAGCTAATTCAGAATCATTTTGTGGACGAATATGATCCAACAATAGAGGATTCCTACAGGAA",
        "variants": [
            {"pos": 34, "ref": "G", "alt": "T", "name": "G12V", "known_effect": "pathogenic", "clinical": "Locks KRAS in active state — constitutive proliferation signal"},
            {"pos": 34, "ref": "G", "alt": "A", "name": "G12D", "known_effect": "pathogenic", "clinical": "Most common KRAS mutation in pancreatic cancer"},
            {"pos": 37, "ref": "G", "alt": "A", "name": "G13D", "known_effect": "pathogenic", "clinical": "Constitutively active — common in colorectal cancer"},
            {"pos": 60, "ref": "C", "alt": "A", "name": "Q61K", "known_effect": "pathogenic", "clinical": "Impairs GTP hydrolysis — locked ON state"},
        ],
    },
    "HBB (Sickle Cell)": {
        "description": "Beta-globin subunit of hemoglobin. The sickle cell mutation (E6V) is the most well-known single-base disease variant in humans.",
        "sequence": "ATGGTGCATCTGACTCCTGAGGAGAAGTCTGCCGTTACTGCCCTGTGGGGCAAGGTGAACGTGGATGAAGTTGGTGGTGAGGCCCTGGGCAGGCTGCTGGTGGTCTACCCTTGGACCCAGAGG",
        "variants": [
            {"pos": 19, "ref": "A", "alt": "T", "name": "E6V (HbS)", "known_effect": "pathogenic", "clinical": "THE sickle cell mutation — causes hemoglobin polymerization under low O2"},
            {"pos": 18, "ref": "G", "alt": "A", "name": "E6K (HbC)", "known_effect": "pathogenic", "clinical": "Hemoglobin C disease — milder than sickle cell but causes crystal formation"},
            {"pos": 78, "ref": "G", "alt": "A", "name": "E26K", "known_effect": "benign", "clinical": "Hemoglobin E — most common Hb variant worldwide, mild effect"},
            {"pos": 117, "ref": "C", "alt": "T", "name": "Q39X", "known_effect": "pathogenic", "clinical": "Nonsense — causes beta-thalassemia (no functional beta-globin)"},
        ],
    },
}
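
# Sketch (hypothetical helper, not called by the pipeline): find the 0-indexed
# [start, end) span of a codon in a window that begins at the ATG start codon.
# Protein-change names use different numbering conventions: KRAS G12 counts the
# initiator ATG as codon 1, while legacy HBB names such as E6V skip it (ATG is
# codon 0). Pass first_codon_number accordingly.
def _codon_span(codon_number: int, first_codon_number: int = 1) -> tuple[int, int]:
    start = 3 * (codon_number - first_codon_number)
    return start, start + 3

# Usage: _codon_span(12) == (33, 36) for KRAS G12 (GGT);
# _codon_span(6, first_codon_number=0) == (18, 21) for HBB E6 (GAG).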

# DNA base colors (classic genomics color scheme)
BASE_COLORS = {"A": "#2ecc71", "T": "#e74c3c", "G": "#f39c12", "C": "#3498db"}
BASE_COMPLEMENT = {"A": "T", "T": "A", "G": "C", "C": "G"}

# Pathogenicity color scheme
EFFECT_COLORS = {
    "pathogenic": "#dc2626",
    "benign": "#059669",
    "uncertain": "#f59e0b",
}
EFFECT_BADGES = {
    "pathogenic": "badge-danger",
    "benign": "badge-success",
    "uncertain": "badge-warning",
}

# ------------------------------------------------------------------
# Report styling — genomics-themed deep blues and teals
# ------------------------------------------------------------------

REPORT_CSS = """
<style>
  .report { font-family: system-ui, -apple-system, sans-serif; max-width: 960px; margin: 0 auto; color: #1a1a2e; }
  .report h2 { color: #1e3a5f; border-bottom: 2px solid #2563eb; padding-bottom: 8px; margin-top: 24px; }
  .report h3 { color: #1e40af; margin-top: 20px; }
  .report .card { background: #eff6ff; border: 1px solid #bfdbfe; border-radius: 8px; padding: 16px; margin: 12px 0; }
  .report .stat-grid { display: grid; grid-template-columns: repeat(auto-fit, minmax(160px, 1fr)); gap: 12px; margin: 12px 0; }
  .report .stat { background: #fff; border: 1px solid #dbeafe; border-radius: 6px; padding: 12px; text-align: center; }
  .report .stat .value { font-size: 1.5em; font-weight: 700; color: #1e3a5f; }
  .report .stat .label { font-size: 0.85em; color: #6c757d; margin-top: 4px; }
  .report table { border-collapse: collapse; width: 100%; margin: 12px 0; }
  .report th { background: #1e3a5f; color: #fff; padding: 10px 14px; text-align: left; font-weight: 600; }
  .report td { padding: 8px 14px; border-bottom: 1px solid #dbeafe; }
  .report tr:nth-child(even) { background: #eff6ff; }
  .report .badge { display: inline-block; padding: 2px 8px; border-radius: 12px; font-size: 0.8em; font-weight: 600; }
  .report .badge-success { background: #d1fae5; color: #065f46; }
  .report .badge-warning { background: #fef3c7; color: #92400e; }
  .report .badge-danger { background: #fee2e2; color: #991b1b; }
  .report .badge-info { background: #dbeafe; color: #1e40af; }
  .report .chart-container { background: #fff; border: 1px solid #dbeafe; border-radius: 8px; padding: 16px; margin: 16px 0; }
  .report .note { background: #eff6ff; border-left: 4px solid #2563eb; padding: 10px 14px; border-radius: 4px; margin: 12px 0; font-size: 0.9em; }
  .report .gene-card { background: #fff; border: 1px solid #dbeafe; border-radius: 8px; padding: 16px; margin: 12px 0; }
  .report .dna-track { font-family: 'SF Mono', 'Fira Code', monospace; letter-spacing: 1px; }
</style>
"""

def _wrap_report(html: str) -> str:
    return f'{REPORT_CSS}<div class="report">{html}</div>'

# ------------------------------------------------------------------
# SVG chart helpers
# ------------------------------------------------------------------

def _make_bar_chart(
    labels: list[str],
    series: dict[str, list[float]],
    title: str = "",
    colors: list[str] | None = None,
    width: int = 700,
    height: int = 300,
    value_format: str = ".2f",
) -> str:
    """Generate an SVG grouped bar chart."""
    if not labels:
        return ""

    default_colors = ["#2563eb", "#1e3a5f", "#3b82f6", "#60a5fa", "#93c5fd"]
    colors = colors or default_colors

    ml, mr, mt, mb = 70, 20, 40, 80
    cw = width - ml - mr
    ch = height - mt - mb

    all_vals = [v for vals in series.values() for v in vals]
    y_max = max(abs(v) for v in all_vals) if all_vals else 1
    y_min = min(all_vals) if all_vals else 0
    # For VEP scores (negative = more damaging), we need to handle negative values
    if y_min >= 0:
        y_min_plot = 0
        y_max_plot = y_max * 1.15 or 1
    else:
        y_max_plot = max(y_max * 1.15, 0.1)
        y_min_plot = y_min * 1.15

    y_range = y_max_plot - y_min_plot or 1

    n_groups = len(labels)
    n_series = len(series)
    group_width = cw / n_groups
    bar_width = group_width * 0.7 / max(n_series, 1)
    gap = group_width * 0.15

    def sy(v):
        return mt + ch - (v - y_min_plot) / y_range * ch

    svg = [
        f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {height}" '
        f'style="width:100%;max-width:{width}px;height:auto;">',
        f'<rect width="{width}" height="{height}" fill="#fff" rx="6"/>',
    ]

    # Grid lines
    for i in range(6):
        y_tick = y_min_plot + y_range * i / 5
        py = sy(y_tick)
        svg.append(
            f'<line x1="{ml}" y1="{py:.1f}" x2="{ml + cw}" y2="{py:.1f}" '
            f'stroke="#e9ecef" stroke-width="1"/>'
        )
        svg.append(
            f'<text x="{ml - 8}" y="{py + 4:.1f}" text-anchor="end" '
            f'font-size="11" fill="#6c757d">{y_tick:{value_format}}</text>'
        )

    # Zero line
    if y_min_plot < 0 < y_max_plot:
        zy = sy(0)
        svg.append(
            f'<line x1="{ml}" y1="{zy:.1f}" x2="{ml + cw}" y2="{zy:.1f}" '
            f'stroke="#374151" stroke-width="1.5"/>'
        )

    # Bars
    for gi, label in enumerate(labels):
        gx = ml + gi * group_width + gap
        for si, (name, vals) in enumerate(series.items()):
            color = colors[si % len(colors)]
            bx = gx + si * bar_width
            val = vals[gi]
            if val >= 0:
                by = sy(val)
                bh = sy(0) - by if y_min_plot < 0 else mt + ch - by
            else:
                by = sy(0) if y_min_plot < 0 else mt + ch
                bh = sy(val) - by
            svg.append(
                f'<rect x="{bx:.1f}" y="{by:.1f}" width="{bar_width - 1:.1f}" '
                f'height="{max(0, bh):.1f}" fill="{color}" rx="2"/>'
            )
            text_y = by - 4 if val >= 0 else by + bh + 12
            svg.append(
                f'<text x="{bx + bar_width / 2:.1f}" y="{text_y:.1f}" '
                f'text-anchor="middle" font-size="9" fill="#1a1a2e">'
                f'{val:{value_format}}</text>'
            )
        # Rotated group label
        lx = gx + n_series * bar_width / 2
        svg.append(
            f'<text x="{lx:.1f}" y="{mt + ch + 14}" '
            f'text-anchor="end" font-size="10" fill="#6c757d" '
            f'transform="rotate(-35, {lx:.1f}, {mt + ch + 14})">{label}</text>'
        )

    # Title
    if title:
        svg.append(
            f'<text x="{width / 2}" y="22" text-anchor="middle" '
            f'font-size="14" font-weight="600" fill="#1a1a2e">{title}</text>'
        )

    # Legend
    if n_series > 1:
        lx = ml + cw - len(series) * 110
        for si, name in enumerate(series):
            color = colors[si % len(colors)]
            svg.append(
                f'<rect x="{lx + si * 110}" y="{mt + ch + 55}" width="12" '
                f'height="12" rx="2" fill="{color}"/>'
            )
            svg.append(
                f'<text x="{lx + si * 110 + 16}" y="{mt + ch + 66}" font-size="11" '
                f'fill="#1a1a2e">{name}</text>'
            )

    svg.append("</svg>")
    return "\n".join(svg)

def _make_heatmap(
    matrix: list[list[float]],
    row_labels: list[str],
    col_labels: list[str],
    title: str = "",
    width: int = 700,
    height: int = 500,
    value_format: str = ".2f",
    diverging: bool = False,
) -> str:
    """Generate an SVG heatmap. If diverging=True, uses red-white-blue scale centered at 0."""
    n_rows = len(matrix)
    n_cols = len(matrix[0]) if matrix else 0
    if not n_rows or not n_cols:
        return ""

    show_values = n_rows <= 10 and n_cols <= 12

    flat = [v for row in matrix for v in row]
    v_min = min(flat)
    v_max = max(flat)

    if diverging:
        abs_max = max(abs(v_min), abs(v_max)) or 1

        def get_color(v):
            t = v / abs_max  # -1 to 1
            if t < 0:
                # White to red (negative = damaging)
                r = 255
                g = int(255 * (1 + t))
                b = int(255 * (1 + t))
            else:
                # White to blue (positive = benign)
                r = int(255 * (1 - t))
                g = int(255 * (1 - t))
                b = 255
            return f"rgb({r},{g},{b})"
    else:
        v_range = v_max - v_min or 1

        def get_color(v):
            t = (v - v_min) / v_range
            r = int(255 - t * (255 - 30))
            g = int(255 - t * (255 - 58))
            b = int(255 - t * (255 - 95))
            return f"rgb({r},{g},{b})"

    # Layout
    ml = max(140, max(len(l) for l in row_labels) * 7 + 20) if row_labels else 140
    mr = 20
    mt = 80 if col_labels else 40
    mb = 30
    cw = width - ml - mr
    ch = height - mt - mb

    cell_w = cw / n_cols
    cell_h = ch / n_rows

    svg = [
        f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {height}" '
        f'style="width:100%;max-width:{width}px;height:auto;">',
        f'<rect width="{width}" height="{height}" fill="#fff" rx="6"/>',
    ]

    if title:
        svg.append(
            f'<text x="{width / 2}" y="22" text-anchor="middle" '
            f'font-size="14" font-weight="600" fill="#1a1a2e">{title}</text>'
        )

    # Column labels (rotated)
    for j, label in enumerate(col_labels):
        cx = ml + j * cell_w + cell_w / 2
        svg.append(
            f'<text x="{cx:.1f}" y="{mt - 8}" text-anchor="end" '
            f'font-size="10" fill="#374151" '
            f'transform="rotate(-45, {cx:.1f}, {mt - 8})">{label}</text>'
        )

    # Row labels + cells
    for i, row_label in enumerate(row_labels):
        ry = mt + i * cell_h + cell_h / 2
        svg.append(
            f'<text x="{ml - 8}" y="{ry + 4:.1f}" text-anchor="end" '
            f'font-size="10" fill="#374151">{row_label}</text>'
        )
        for j in range(n_cols):
            val = matrix[i][j]
            color = get_color(val)
            cx = ml + j * cell_w
            cy = mt + i * cell_h
            svg.append(
                f'<rect x="{cx:.1f}" y="{cy:.1f}" width="{cell_w:.1f}" '
                f'height="{cell_h:.1f}" fill="{color}" stroke="#fff" stroke-width="1"/>'
            )
            if show_values:
                if diverging:
                    t = abs(val) / (max(abs(v_min), abs(v_max)) or 1)
                else:
                    t = (val - v_min) / (v_max - v_min or 1)
                text_color = "#fff" if t > 0.55 else "#1a1a2e"
                font_size = min(10, int(cell_w / 4), int(cell_h / 2.5))
                font_size = max(7, font_size)
                svg.append(
                    f'<text x="{cx + cell_w / 2:.1f}" y="{cy + cell_h / 2 + 3:.1f}" '
                    f'text-anchor="middle" font-size="{font_size}" '
                    f'fill="{text_color}">{val:{value_format}}</text>'
                )

    svg.append("</svg>")
    return "\n".join(svg)

def _make_dna_track(
    sequence: str,
    variants: list[dict],
    gene_name: str = "",
    width: int = 900,
) -> str:
    """Render a color-coded DNA sequence track with variant positions highlighted."""
    chars_per_line = 60
    char_w = 11
    line_h = 22
    label_w = 50
    n_lines = (len(sequence) + chars_per_line - 1) // chars_per_line

    # Set of 0-indexed variant positions to highlight
    variant_positions = {v["pos"] for v in variants}
    # Height: one row per line of bases, plus room for the title and the legend
    svg_h = n_lines * line_h + 60

    svg = [
        f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {svg_h}" '
        f'style="width:100%;max-width:{width}px;height:auto;font-family:monospace;">',
        f'<rect width="{width}" height="{svg_h}" fill="#fff" rx="6"/>',
    ]

    if gene_name:
        svg.append(
            f'<text x="{width / 2}" y="16" text-anchor="middle" '
            f'font-size="12" font-weight="600" fill="#1e3a5f">{gene_name}</text>'
        )

    y_offset = 28

    for line_idx in range(n_lines):
        start = line_idx * chars_per_line
        end = min(start + chars_per_line, len(sequence))
        chunk = sequence[start:end]
        y = y_offset + line_idx * line_h

        # Position label
        svg.append(
            f'<text x="{label_w - 8}" y="{y}" text-anchor="end" '
            f'font-size="9" fill="#9ca3af">{start + 1}</text>'
        )

        for ci, base in enumerate(chunk):
            abs_pos = start + ci
            x = label_w + ci * char_w
            color = BASE_COLORS.get(base, "#6b7280")
            is_variant = abs_pos in variant_positions

            if is_variant:
                # Red highlight for variant positions
                svg.append(
                    f'<rect x="{x - 1}" y="{y - 13}" width="{char_w + 2}" height="{line_h}" '
                    f'fill="#fee2e2" stroke="#dc2626" stroke-width="1.5" rx="2"/>'
                )
                # Small triangle marker above
                svg.append(
                    f'<polygon points="{x + char_w / 2:.1f},{y - 16} {x + char_w / 2 - 3:.1f},{y - 20} {x + char_w / 2 + 3:.1f},{y - 20}" '
                    f'fill="#dc2626"/>'
                )
            else:
                # Subtle background
                svg.append(
                    f'<rect x="{x}" y="{y - 12}" width="{char_w}" height="{line_h - 2}" '
                    f'fill="{color}" opacity="0.08" rx="1"/>'
                )

            svg.append(
                f'<text x="{x + char_w / 2:.1f}" y="{y}" text-anchor="middle" '
                f'font-size="11" font-weight="{"700" if is_variant else "500"}" '
                f'fill="{color}">{base}</text>'
            )

            # Spacer every 10 bases
            if (abs_pos + 1) % 10 == 0 and ci < len(chunk) - 1:
                svg.append(
                    f'<line x1="{x + char_w + 1}" y1="{y - 11}" '
                    f'x2="{x + char_w + 1}" y2="{y + 4}" '
                    f'stroke="#e5e7eb" stroke-width="0.5"/>'
                )

    # Legend
    ly = svg_h - 12
    lx = label_w
    for base, color in BASE_COLORS.items():
        svg.append(f'<rect x="{lx}" y="{ly - 8}" width="8" height="8" rx="1" fill="{color}"/>')
        svg.append(f'<text x="{lx + 12}" y="{ly}" font-size="9" fill="#6b7280">{base}</text>')
        lx += 30
    lx += 10
    svg.append(f'<rect x="{lx}" y="{ly - 8}" width="8" height="8" rx="1" fill="#fee2e2" stroke="#dc2626" stroke-width="1"/>')
    svg.append(f'<text x="{lx + 12}" y="{ly}" font-size="9" fill="#6b7280">Variant site</text>')

    svg.append("</svg>")
    return "\n".join(svg)

def _make_lollipop_plot(
    variants: list[dict],
    seq_length: int,
    title: str = "",
    width: int = 700,
    height: int = 200,
) -> str:
    """Generate a lollipop plot showing variant positions along a gene with effect scores."""
    if not variants:
        return ""

    ml, mr, mt, mb = 50, 30, 40, 40
    cw = width - ml - mr
    ch = height - mt - mb

    svg = [
        f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {height}" '
        f'style="width:100%;max-width:{width}px;height:auto;">',
        f'<rect width="{width}" height="{height}" fill="#fff" rx="6"/>',
    ]

    if title:
        svg.append(
            f'<text x="{width / 2}" y="22" text-anchor="middle" '
            f'font-size="13" font-weight="600" fill="#1a1a2e">{title}</text>'
        )

    # Gene track (horizontal bar)
    track_y = mt + ch * 0.7
    svg.append(
        f'<rect x="{ml}" y="{track_y - 4}" width="{cw}" height="8" '
        f'fill="#dbeafe" rx="4" stroke="#93c5fd" stroke-width="1"/>'
    )

    # Position labels at ends
    svg.append(
        f'<text x="{ml}" y="{track_y + 20}" text-anchor="middle" '
        f'font-size="9" fill="#9ca3af">1</text>'
    )
    svg.append(
        f'<text x="{ml + cw}" y="{track_y + 20}" text-anchor="middle" '
        f'font-size="9" fill="#9ca3af">{seq_length}</text>'
    )

    # Lollipops
    scores = [v.get("score", 0) for v in variants]
    # Fall back to 1 when every score is 0, to avoid dividing by zero when sizing stems.
    max_score = max(abs(s) for s in scores) or 1

    for v in variants:
        pos = v["pos"]
        score = v.get("score", 0)
        effect = v.get("known_effect", "uncertain")
        color = EFFECT_COLORS.get(effect, "#6b7280")

        x = ml + (pos / max(seq_length - 1, 1)) * cw
        stem_h = max(20, abs(score) / max_score * (ch * 0.6))
        circle_y = track_y - stem_h - 6

        # Stem
        svg.append(
            f'<line x1="{x:.1f}" y1="{track_y - 4}" x2="{x:.1f}" y2="{circle_y + 5:.1f}" '
            f'stroke="{color}" stroke-width="2"/>'
        )
        # Circle
        svg.append(
            f'<circle cx="{x:.1f}" cy="{circle_y:.1f}" r="5" fill="{color}" '
            f'stroke="#fff" stroke-width="1.5"/>'
        )
        # Label
        svg.append(
            f'<text x="{x:.1f}" y="{circle_y - 9:.1f}" text-anchor="middle" '
            f'font-size="8" fill="#374151" font-weight="600">{v.get("name", "")}</text>'
        )

    # Legend
    lx = ml
    for effect, color in EFFECT_COLORS.items():
        svg.append(f'<circle cx="{lx + 4}" cy="{height - 10}" r="4" fill="{color}"/>')
        svg.append(
            f'<text x="{lx + 12}" y="{height - 6}" font-size="9" fill="#374151">'
            f'{effect.title()}</text>'
        )
        lx += len(effect) * 7 + 24

    svg.append("</svg>")
    return "\n".join(svg)

def _make_score_comparison_chart(
    gene_results: dict,
    width: int = 800,
    height: int = 350,
) -> str:
    """Generate a dot plot comparing VEP scores across all genes, colored by known effect."""
    all_variants = []
    for gene_name, gene_data in gene_results.items():
        for v in gene_data["variants"]:
            all_variants.append({
                "gene": gene_name.split("(")[0].strip(),
                "name": v["name"],
                "score": v.get("score", 0),
                "known_effect": v["known_effect"],
            })

    if not all_variants:
        return ""

    ml, mr, mt, mb = 120, 30, 40, 30
    cw = width - ml - mr
    ch = height - mt - mb

    scores = [v["score"] for v in all_variants]
    s_min, s_max = min(scores), max(scores)
    s_pad = (s_max - s_min) * 0.1 or 0.5
    s_min -= s_pad
    s_max += s_pad
    s_range = s_max - s_min or 1

    def sx(s):
        return ml + (s - s_min) / s_range * cw

    svg = [
        f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {height}" '
        f'style="width:100%;max-width:{width}px;height:auto;">',
        f'<rect width="{width}" height="{height}" fill="#fff" rx="6"/>',
        f'<text x="{width / 2}" y="22" text-anchor="middle" '
        f'font-size="14" font-weight="600" fill="#1a1a2e">'
        f'Variant Effect Scores — All Genes</text>',
    ]

    # Grid
    for i in range(6):
        gx = s_min + s_range * i / 5
        px = sx(gx)
        svg.append(
            f'<line x1="{px:.1f}" y1="{mt}" x2="{px:.1f}" y2="{mt + ch}" '
            f'stroke="#f3f4f6" stroke-width="1"/>'
        )
        svg.append(
            f'<text x="{px:.1f}" y="{mt + ch + 16}" text-anchor="middle" '
            f'font-size="10" fill="#9ca3af">{gx:.2f}</text>'
        )

    # Zero line
    if s_min < 0 < s_max:
        zx = sx(0)
        svg.append(
            f'<line x1="{zx:.1f}" y1="{mt}" x2="{zx:.1f}" y2="{mt + ch}" '
            f'stroke="#374151" stroke-width="1" stroke-dasharray="4,3"/>'
        )

    # Rows per variant
    row_h = ch / max(len(all_variants), 1)
    for i, v in enumerate(all_variants):
        y = mt + i * row_h + row_h / 2
        color = EFFECT_COLORS.get(v["known_effect"], "#6b7280")
        px = sx(v["score"])

        # Label
        label = f'{v["gene"]} {v["name"]}'
        svg.append(
            f'<text x="{ml - 8}" y="{y + 3:.1f}" text-anchor="end" '
            f'font-size="9" fill="#374151">{label}</text>'
        )
        # Connector line
        svg.append(
            f'<line x1="{ml}" y1="{y:.1f}" x2="{px:.1f}" y2="{y:.1f}" '
            f'stroke="{color}" stroke-width="1" opacity="0.3"/>'
        )
        # Dot
        svg.append(
            f'<circle cx="{px:.1f}" cy="{y:.1f}" r="5" fill="{color}" '
            f'stroke="#fff" stroke-width="1"/>'
        )

    svg.append("</svg>")
    return "\n".join(svg)

# ------------------------------------------------------------------
# Task 1: Load and validate gene variants
# ------------------------------------------------------------------

@cpu_env.task(cache="auto")
async def load_variants(
    variants_json: str = "",
) -> flyte.io.Dir:
    """Load gene variant definitions, validate sequences, and save to a temp directory."""
    if variants_json:
        genes = json.loads(variants_json)
    else:
        genes = DEFAULT_GENE_VARIANTS

    # Validate
    valid_bases = set("ATGC")
    for gene_name, gene_data in genes.items():
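        seq = gene_data["sequence"].upper()
        invalid = set(seq) - valid_bases
        if invalid:
            log.warning(
                f"{gene_name}: invalid bases {invalid}; removing them "
                "(variant positions after a removed base will shift)"
            )
            seq = "".join(b for b in seq if b in valid_bases)
        # Always store the normalized sequence so lowercase input is uppercased for scoring too.
        gene_data["sequence"] = seq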

        for v in gene_data["variants"]:
            pos = v["pos"]
            if pos < 0 or pos >= len(seq):
                log.warning(f"{gene_name} variant {v['name']}: position {pos} out of range [0, {len(seq)})")
            elif seq[pos] != v["ref"]:
                log.warning(f"{gene_name} variant {v['name']}: expected ref={v['ref']} at pos {pos}, found {seq[pos]}")

    total_variants = sum(len(g["variants"]) for g in genes.values())
    log.info(f"Loaded {len(genes)} genes with {total_variants} variants")

    out_dir = tempfile.mkdtemp(prefix="genomic_vep_")
    with open(os.path.join(out_dir, "genes.json"), "w") as f:
        json.dump(genes, f)

    return await flyte.io.Dir.from_local(out_dir)
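
# Example (hypothetical helper, not called by the pipeline): building a custom
# `variants_json` payload for load_variants. It follows the same schema as
# DEFAULT_GENE_VARIANTS; the gene name, sequence, and variant below are made-up
# placeholders, not real clinical data.
def _example_variants_json() -> str:
    sequence = "ATGGCTAAAGGTTCCGAATAG"  # placeholder 21 bp sequence
    pos = 9  # 0-indexed; sequence[9] == "G"
    payload = {
        "TOY1 (Example)": {
            "description": "Placeholder gene used only to illustrate the input format.",
            "sequence": sequence,
            "variants": [
                {
                    "pos": pos,
                    "ref": sequence[pos],
                    "alt": "A",
                    "name": "toy.10G>A",  # 1-based position in the name, 0-based in "pos"
                    "known_effect": "uncertain",
                    "clinical": "Illustrative variant",
                }
            ],
        }
    }
    return json.dumps(payload)

# Usage (sketch): pass _example_variants_json() as the variants_json input to load_variants.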

# ------------------------------------------------------------------
# Task 2: Run Carbon model for variant effect scoring
# ------------------------------------------------------------------

@gpu_env.task(report=True)
async def score_variants(
    variants_dir: flyte.io.Dir,
    model_name: str = "HuggingFaceBio/Carbon-3B",
) -> str:
    """Score each variant using Carbon's log-likelihood ratio.

    For each variant, we compute:
        score = log P(alt_sequence) - log P(ref_sequence)

    A negative score means the model considers the variant less likely than
    the reference — suggestive of a damaging/pathogenic effect. A score near
    zero means the model sees little difference (likely benign).
    """
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    log.info(f"Loading Carbon model: {model_name}")

    # Load model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cpu":
        log.warning("Running on CPU — inference will be slow. GPU recommended for production.")

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        dtype=torch.bfloat16 if device == "cuda" else torch.float32,
    ).to(device)
    model.eval()

    # Load variants
    variants_path = await variants_dir.download()
    with open(os.path.join(variants_path, "genes.json")) as f:
        genes = json.load(f)

    results = {}
    total_variants = sum(len(g["variants"]) for g in genes.values())
    scored = 0

    progress_html = """
    <h2>Carbon Variant Effect Scoring</h2>
    <div class="card">
        <b>Model:</b> {model}<br>
        <b>Device:</b> {device}<br>
        <b>Progress:</b> {scored}/{total} variants scored
    </div>
    """

    for gene_name, gene_data in genes.items():
        ref_seq = gene_data["sequence"]
        gene_results = {
            "description": gene_data.get("description", ""),
            "sequence": ref_seq,
            "variants": [],
        }

        # Score reference sequence
        ref_prompt = f"<dna>{ref_seq}"
        ref_inputs = tokenizer(ref_prompt, return_tensors="pt", add_special_tokens=False).to(device)

        with torch.no_grad():
            ref_output = model(**ref_inputs, labels=ref_inputs["input_ids"])
            ref_loss = ref_output.loss.item()
            # Assuming the standard Hugging Face causal-LM loss: labels are shifted by one,
            # so the loss is the mean NLL over the n - 1 predicted tokens. Scale by n - 1
            # to recover the total log-likelihood of the sequence.
            ref_ll = -ref_loss * (ref_inputs["input_ids"].shape[1] - 1)

        for variant in gene_data["variants"]:
            scored += 1
            await flyte.report.replace.aio(
                _wrap_report(progress_html.format(
                    model=model_name, device=device, scored=scored, total=total_variants
                )),
                do_flush=True,
            )

            # Create mutant sequence
            pos = variant["pos"]
            alt_seq = ref_seq[:pos] + variant["alt"] + ref_seq[pos + 1:]

            # Score mutant
            alt_prompt = f"<dna>{alt_seq}"
            alt_inputs = tokenizer(alt_prompt, return_tensors="pt", add_special_tokens=False).to(device)

            with torch.no_grad():
                alt_output = model(**alt_inputs, labels=alt_inputs["input_ids"])
                alt_loss = alt_output.loss.item()
                alt_ll = -alt_loss * (alt_inputs["input_ids"].shape[1] - 1)

            # VEP score: positive = model prefers alt (likely benign), negative = model prefers ref (likely pathogenic)
            vep_score = alt_ll - ref_ll

            gene_results["variants"].append({
                **variant,
                "score": round(vep_score, 4),
                "ref_ll": round(ref_ll, 4),
                "alt_ll": round(alt_ll, 4),
            })

            log.info(
                f"  {gene_name} | {variant['name']}: score={vep_score:.4f} "
                f"(known: {variant['known_effect']})"
            )

        results[gene_name] = gene_results

    # Generate scoring report
    html_parts = [
        "<h2>Carbon Variant Effect Scoring</h2>",
        '<div class="stat-grid">',
        f'<div class="stat"><div class="value">{len(genes)}</div><div class="label">Genes</div></div>',
        f'<div class="stat"><div class="value">{total_variants}</div><div class="label">Variants Scored</div></div>',
        f'<div class="stat"><div class="value">{model_name.split("/")[-1]}</div><div class="label">Model</div></div>',
        f'<div class="stat"><div class="value">{device.upper()}</div><div class="label">Device</div></div>',
        "</div>",
    ]

    # Per-gene tables
    for gene_name, gene_data in results.items():
        html_parts.append(f'<h3>{gene_name}</h3>')
        html_parts.append(f'<div class="note">{gene_data["description"]}</div>')
        html_parts.append("<table><tr><th>Variant</th><th>Ref</th><th>Alt</th><th>VEP Score</th><th>Known Effect</th><th>Clinical</th></tr>")

        for v in gene_data["variants"]:
            badge = EFFECT_BADGES.get(v["known_effect"], "badge-info")
            direction = "damaging" if v["score"] < -0.1 else "neutral" if abs(v["score"]) <= 0.1 else "tolerated"
            html_parts.append(
                f'<tr>'
                f'<td><b>{v["name"]}</b></td>'
                f'<td style="color:{BASE_COLORS.get(v["ref"], "#333")};font-weight:700">{v["ref"]}</td>'
                f'<td style="color:{BASE_COLORS.get(v["alt"], "#333")};font-weight:700">{v["alt"]}</td>'
                f'<td><b>{v["score"]:.4f}</b> <span style="font-size:0.8em;color:#6c757d">({direction})</span></td>'
                f'<td><span class="badge {badge}">{v["known_effect"]}</span></td>'
                f'<td style="font-size:0.85em">{v.get("clinical", "")}</td>'
                f'</tr>'
            )
        html_parts.append("</table>")

    await flyte.report.replace.aio(_wrap_report("\n".join(html_parts)), do_flush=True)

    return json.dumps(results)

# ------------------------------------------------------------------
# Task 3: Analyze and visualize variant effects
# ------------------------------------------------------------------

@cpu_env.task(report=True)
async def analyze_effects(
    scores_json: str,
    variants_dir: flyte.io.Dir,
) -> str:
    """Analyze VEP scores: classification accuracy, gene-level summaries, and rich visualizations."""
    results = json.loads(scores_json)

    variants_path = await variants_dir.download()
    with open(os.path.join(variants_path, "genes.json")) as f:
        genes = json.load(f)

    html_parts = ["<h2>Variant Effect Analysis</h2>"]

    # ------------------------------------------------------------------
    # Overall accuracy: does the model's score direction match known labels?
    # ------------------------------------------------------------------
    all_variants = []
    correct = 0
    total_known = 0
    true_pos = 0
    false_pos = 0
    true_neg = 0
    false_neg = 0

    for gene_name, gene_data in results.items():
        for v in gene_data["variants"]:
            all_variants.append({**v, "gene": gene_name})
            if v["known_effect"] in ("pathogenic", "benign"):
                total_known += 1
                predicted_pathogenic = v["score"] < -0.05
                actual_pathogenic = v["known_effect"] == "pathogenic"
                if predicted_pathogenic == actual_pathogenic:
                    correct += 1
                if predicted_pathogenic and actual_pathogenic:
                    true_pos += 1
                elif predicted_pathogenic and not actual_pathogenic:
                    false_pos += 1
                elif not predicted_pathogenic and actual_pathogenic:
                    false_neg += 1
                else:
                    true_neg += 1

    accuracy = correct / total_known if total_known else 0
    precision = true_pos / (true_pos + false_pos) if (true_pos + false_pos) else 0
    recall = true_pos / (true_pos + false_neg) if (true_pos + false_neg) else 0

    html_parts.append('<div class="stat-grid">')
    html_parts.append(f'<div class="stat"><div class="value">{accuracy:.0%}</div><div class="label">Direction Accuracy</div></div>')
    html_parts.append(f'<div class="stat"><div class="value">{precision:.0%}</div><div class="label">Precision (Pathogenic)</div></div>')
    html_parts.append(f'<div class="stat"><div class="value">{recall:.0%}</div><div class="label">Recall (Pathogenic)</div></div>')
    html_parts.append(f'<div class="stat"><div class="value">{len(all_variants)}</div><div class="label">Total Variants</div></div>')
    html_parts.append("</div>")

    html_parts.append(
        '<div class="note">'
        "<b>How to read VEP scores:</b> Negative scores mean Carbon considers the variant "
        "less likely than the reference sequence, which suggests a damaging effect. Scores near "
        "zero indicate the model sees little difference (likely benign). Larger magnitudes reflect "
        "a larger change in likelihood, not a calibrated probability of pathogenicity."
        "</div>"
    )

    # ------------------------------------------------------------------
    # Cross-gene score comparison dot plot
    # ------------------------------------------------------------------
    html_parts.append('<div class="chart-container">')
    html_parts.append(_make_score_comparison_chart(results))
    html_parts.append("</div>")

    # ------------------------------------------------------------------
    # Per-gene visualizations
    # ------------------------------------------------------------------
    for gene_name, gene_data in results.items():
        short_name = gene_name.split("(")[0].strip()
        html_parts.append(f'<h3>{gene_name}</h3>')
        html_parts.append(f'<div class="note">{gene_data["description"]}</div>')

        # DNA track with variant positions highlighted
        html_parts.append('<div class="chart-container">')
        html_parts.append(_make_dna_track(
            gene_data["sequence"],
            gene_data["variants"],
            gene_name=f"{short_name} Reference Sequence",
        ))
        html_parts.append("</div>")

        # Lollipop plot
        html_parts.append('<div class="chart-container">')
        html_parts.append(_make_lollipop_plot(
            gene_data["variants"],
            len(gene_data["sequence"]),
            title=f"{short_name} — Variant Positions & Effect Scores",
        ))
        html_parts.append("</div>")

        # Score bar chart for this gene
        variant_names = [v["name"] for v in gene_data["variants"]]
        variant_scores = [v["score"] for v in gene_data["variants"]]
        html_parts.append('<div class="chart-container">')
        html_parts.append(_make_bar_chart(
            variant_names,
            {"VEP Score": variant_scores},
            title=f"{short_name} — Log-Likelihood Ratio Scores",
            colors=[EFFECT_COLORS.get(v["known_effect"], "#6b7280") for v in gene_data["variants"]],
            value_format=".3f",
        ))
        html_parts.append("</div>")

        # Variant detail cards
        for v in gene_data["variants"]:
            badge = EFFECT_BADGES.get(v["known_effect"], "badge-info")
            score_color = "#dc2626" if v["score"] < -0.1 else "#059669" if v["score"] > 0.05 else "#f59e0b"
            html_parts.append(
                f'<div class="gene-card">'
                f'<div style="display:flex;justify-content:space-between;align-items:center;margin-bottom:8px;">'
                f'<b style="font-size:1.1em">{v["name"]}</b>'
                f'<span class="badge {badge}">{v["known_effect"]}</span>'
                f'</div>'
                f'<div style="display:grid;grid-template-columns:1fr 1fr 1fr;gap:8px;">'
                f'<div><span style="color:#6c757d;font-size:0.85em">Ref base:</span> '
                f'<b style="color:{BASE_COLORS.get(v["ref"], "#333")}">{v["ref"]}</b></div>'
                f'<div><span style="color:#6c757d;font-size:0.85em">Alt base:</span> '
                f'<b style="color:{BASE_COLORS.get(v["alt"], "#333")}">{v["alt"]}</b></div>'
                f'<div><span style="color:#6c757d;font-size:0.85em">VEP Score:</span> '
                f'<b style="color:{score_color}">{v["score"]:.4f}</b></div>'
                f'</div>'
                f'<div style="margin-top:8px;font-size:0.9em;color:#374151">{v.get("clinical", "")}</div>'
                f'</div>'
            )

    # ------------------------------------------------------------------
    # Confusion matrix as heatmap
    # ------------------------------------------------------------------
    html_parts.append("<h3>Classification Performance</h3>")
    html_parts.append(
        '<div class="note">'
        "Using a simple threshold (score &lt; -0.05 = predicted pathogenic). "
        "This is zero-shot — no training on these specific variants."
        "</div>"
    )

    conf_matrix = [[true_pos, false_neg], [false_pos, true_neg]]
    html_parts.append('<div class="chart-container">')
    html_parts.append(_make_heatmap(
        conf_matrix,
        ["Actual Pathogenic", "Actual Benign"],
        ["Predicted Pathogenic", "Predicted Benign"],
        title="Confusion Matrix (Known Variants Only)",
        value_format=".0f",
        width=400,
        height=300,
    ))
    html_parts.append("</div>")

    # ------------------------------------------------------------------
    # Score distribution by known effect
    # ------------------------------------------------------------------
    html_parts.append("<h3>Score Distribution by Known Effect</h3>")

    pathogenic_scores = [v["score"] for v in all_variants if v["known_effect"] == "pathogenic"]
    benign_scores = [v["score"] for v in all_variants if v["known_effect"] == "benign"]
    uncertain_scores = [v["score"] for v in all_variants if v["known_effect"] == "uncertain"]

    stats_html = '<div style="display:grid;grid-template-columns:1fr 1fr 1fr;gap:12px;">'
    for label, scores, color in [
        ("Pathogenic", pathogenic_scores, "#dc2626"),
        ("Benign", benign_scores, "#059669"),
        ("Uncertain", uncertain_scores, "#f59e0b"),
    ]:
        if scores:
            mean_s = sum(scores) / len(scores)
            min_s = min(scores)
            max_s = max(scores)
            stats_html += (
                f'<div class="gene-card" style="border-left:4px solid {color}">'
                f'<b style="color:{color}">{label}</b> (n={len(scores)})<br>'
                f'Mean: {mean_s:.4f}<br>'
                f'Range: [{min_s:.4f}, {max_s:.4f}]'
                f'</div>'
            )
    stats_html += "</div>"
    html_parts.append(stats_html)

    # Summary
    analysis = {
        "total_variants": len(all_variants),
        "total_known": total_known,
        "accuracy": round(accuracy, 4),
        "precision": round(precision, 4),
        "recall": round(recall, 4),
        "true_pos": true_pos,
        "false_pos": false_pos,
        "true_neg": true_neg,
        "false_neg": false_neg,
        "pathogenic_mean_score": round(sum(pathogenic_scores) / len(pathogenic_scores), 4) if pathogenic_scores else None,
        "benign_mean_score": round(sum(benign_scores) / len(benign_scores), 4) if benign_scores else None,
    }

    await flyte.report.replace.aio(_wrap_report("\n".join(html_parts)), do_flush=True)
    return json.dumps(analysis)

# ------------------------------------------------------------------
# Task 4: Generate comprehensive summary report
# ------------------------------------------------------------------

@cpu_env.task(report=True)
async def generate_summary(
    scores_json: str,
    analysis_json: str,
) -> str:
    """Generate the final summary report combining all results."""
    results = json.loads(scores_json)
    analysis = json.loads(analysis_json)

    html_parts = [
        "<h2>Genomic Variant Effect Prediction — Summary</h2>",
        '<div class="note">'
        "This pipeline uses <b>HuggingFace Carbon</b>, an autoregressive genomic foundation model "
        "trained on 1 trillion tokens of DNA sequence, to perform <b>zero-shot variant effect "
        "prediction</b>. No fine-tuning or labeled training data was used — the model scores "
        "variants purely based on its learned understanding of DNA sequence grammar."
        "</div>",
    ]

    # Key metrics
    html_parts.append('<div class="stat-grid">')
    html_parts.append(f'<div class="stat"><div class="value">{len(results)}</div><div class="label">Genes Analyzed</div></div>')
    html_parts.append(f'<div class="stat"><div class="value">{analysis["total_variants"]}</div><div class="label">Variants Scored</div></div>')
    html_parts.append(f'<div class="stat"><div class="value">{analysis["accuracy"]:.0%}</div><div class="label">Direction Accuracy</div></div>')
    html_parts.append(f'<div class="stat"><div class="value">{analysis["precision"]:.0%}</div><div class="label">Precision</div></div>')
    html_parts.append(f'<div class="stat"><div class="value">{analysis["recall"]:.0%}</div><div class="label">Recall</div></div>')
    html_parts.append("</div>")

    # Gene summary table
    html_parts.append("<h3>Per-Gene Summary</h3>")
    html_parts.append(
        "<table><tr><th>Gene</th><th>Variants</th><th>Mean Score</th>"
        "<th>Pathogenic</th><th>Benign</th><th>Uncertain</th></tr>"
    )

    for gene_name, gene_data in results.items():
        variants = gene_data["variants"]
        scores = [v["score"] for v in variants]
        mean_score = sum(scores) / len(scores) if scores else 0
        n_path = sum(1 for v in variants if v["known_effect"] == "pathogenic")
        n_benign = sum(1 for v in variants if v["known_effect"] == "benign")
        n_unc = sum(1 for v in variants if v["known_effect"] == "uncertain")
        short = gene_name.split("(")[0].strip()

        html_parts.append(
            f"<tr><td><b>{short}</b></td><td>{len(variants)}</td>"
            f"<td>{mean_score:.4f}</td>"
            f'<td><span class="badge badge-danger">{n_path}</span></td>'
            f'<td><span class="badge badge-success">{n_benign}</span></td>'
            f'<td><span class="badge badge-warning">{n_unc}</span></td></tr>'
        )
    html_parts.append("</table>")

    # Cross-gene heatmap: gene x metric
    gene_names = [g.split("(")[0].strip() for g in results.keys()]
    metrics = ["Mean Score", "Min Score", "Max Score", "# Variants"]
    matrix = []
    for gene_data in results.values():
        scores = [v["score"] for v in gene_data["variants"]]
        matrix.append([
            sum(scores) / len(scores) if scores else 0,
            min(scores) if scores else 0,
            max(scores) if scores else 0,
            len(scores),
        ])

    html_parts.append("<h3>Gene-Level Metrics</h3>")
    html_parts.append('<div class="chart-container">')
    html_parts.append(_make_heatmap(
        matrix,
        gene_names,
        metrics,
        title="Gene-Level VEP Score Summary",
        value_format=".2f",
        width=600,
        height=350,
    ))
    html_parts.append("</div>")

    # All variants ranked by score (most damaging first)
    html_parts.append("<h3>All Variants Ranked by Impact</h3>")
    all_vars_sorted = []
    for gene_name, gene_data in results.items():
        for v in gene_data["variants"]:
            all_vars_sorted.append({**v, "gene": gene_name.split("(")[0].strip()})
    all_vars_sorted.sort(key=lambda x: x["score"])

    html_parts.append(
        "<table><tr><th>#</th><th>Gene</th><th>Variant</th><th>Score</th>"
        "<th>Known</th><th>Clinical Significance</th></tr>"
    )
    for i, v in enumerate(all_vars_sorted):
        badge = EFFECT_BADGES.get(v["known_effect"], "badge-info")
        score_color = "#dc2626" if v["score"] < -0.1 else "#059669" if v["score"] > 0.05 else "#f59e0b"
        html_parts.append(
            f'<tr><td>{i + 1}</td><td><b>{v["gene"]}</b></td><td>{v["name"]}</td>'
            f'<td style="color:{score_color};font-weight:700">{v["score"]:.4f}</td>'
            f'<td><span class="badge {badge}">{v["known_effect"]}</span></td>'
            f'<td style="font-size:0.85em">{v.get("clinical", "")}</td></tr>'
        )
    html_parts.append("</table>")

    # Method note
    html_parts.append(
        '<div class="note">'
        "<b>Method:</b> Zero-shot variant effect prediction using log-likelihood ratio scoring. "
        "For each variant, we compute score = log P(mutant sequence | Carbon) - log P(reference sequence | Carbon). "
        "Negative scores indicate the model considers the mutant less probable than the reference, "
        "which tends to correlate with pathogenicity. This approach requires no fine-tuning and "
        "can be applied to any gene or single-base variant without retraining.<br><br>"
        "<b>Limitations:</b> These are short sequence windows — real clinical VEP would use longer "
        "genomic context (Carbon supports up to 786kbp). The threshold for pathogenicity classification "
        "(-0.05) is a simple heuristic; clinical use requires calibrated thresholds per gene."
        "</div>"
    )

    await flyte.report.replace.aio(_wrap_report("\n".join(html_parts)), do_flush=True)
    return json.dumps({"status": "complete", "analysis": analysis})

# ------------------------------------------------------------------
# Pipeline orchestrator
# ------------------------------------------------------------------

# {{docs-fragment pipeline}}
@cpu_env.task(report=True)
async def pipeline(
    variants_json: str = "",
    model_name: str = "HuggingFaceBio/Carbon-3B",
) -> tuple[str, str]:
    """
    End-to-end genomic variant effect prediction pipeline.

    Returns (scores JSON, analysis JSON).

    1. Load and validate gene variants
    2. Score variants with Carbon (log-likelihood ratio)
    3. Analyze effects — accuracy, visualizations, classification
    4. Generate comprehensive summary report
    """
    log.info("Starting genomic variant effect prediction pipeline...")

    def _pipeline_progress(step: int, label: str) -> str:
        steps = [
            "Load Variants",
            "Carbon Scoring",
            "Analyze Effects",
            "Generate Summary",
        ]
        dots = ""
        for i, s in enumerate(steps):
            if i + 1 < step:
                icon = '<span style="color:#2563eb;">&#10003;</span>'
            elif i + 1 == step:
                icon = '<span style="color:#2563eb;">&#9679;</span>'
            else:
                icon = '<span style="color:#adb5bd;">&#9675;</span>'
            dots += f"<span style='margin:0 8px;'>{icon} {s}</span>"
        return f"""
        <h2>Genomic Variant Effect Prediction</h2>
        <div class="card" style="text-align:center;">{dots}</div>
        <p>{label}</p>
        """

    # Stage 1: Load variants
    await flyte.report.replace.aio(
        _wrap_report(_pipeline_progress(1, "Loading and validating gene variants...")),
        do_flush=True,
    )
    var_dir = await load_variants(variants_json=variants_json)

    # Stage 2: Score with Carbon
    await flyte.report.replace.aio(
        _wrap_report(_pipeline_progress(2, "Running Carbon model for variant effect scoring...")),
        do_flush=True,
    )
    scores_json = await score_variants(variants_dir=var_dir, model_name=model_name)

    # Stage 3: Analyze effects
    await flyte.report.replace.aio(
        _wrap_report(_pipeline_progress(3, "Analyzing variant effects and generating visualizations...")),
        do_flush=True,
    )
    analysis_json = await analyze_effects(scores_json=scores_json, variants_dir=var_dir)

    # Stage 4: Summary
    await flyte.report.replace.aio(
        _wrap_report(_pipeline_progress(4, "Generating comprehensive summary report...")),
        do_flush=True,
    )
    summary_json = await generate_summary(scores_json=scores_json, analysis_json=analysis_json)

    # Final pipeline report
    analysis = json.loads(analysis_json)
    results = json.loads(scores_json)

    final_html = f"""
    <h2>Pipeline Complete</h2>
    <div class="stat-grid">
      <div class="stat"><div class="value">{len(results)}</div><div class="label">Genes Analyzed</div></div>
      <div class="stat"><div class="value">{analysis['total_variants']}</div><div class="label">Variants Scored</div></div>
      <div class="stat"><div class="value">{analysis['accuracy']:.0%}</div><div class="label">Direction Accuracy</div></div>
      <div class="stat"><div class="value">{analysis['precision']:.0%}</div><div class="label">Precision</div></div>
      <div class="stat"><div class="value">{analysis['recall']:.0%}</div><div class="label">Recall</div></div>
    </div>
    <div class="card">
      <b>Model:</b> {model_name} |
      <b>Method:</b> Zero-shot log-likelihood ratio scoring |
      <b>Genes:</b> {', '.join(g.split('(')[0].strip() for g in results.keys())}
    </div>
    <div class="note">
      All 4 pipeline stages completed successfully. View individual task reports for detailed
      visualizations including DNA sequence tracks, variant lollipop plots, VEP score charts,
      confusion matrices, and ranked variant tables.
    </div>
    """

    await flyte.report.replace.aio(_wrap_report(final_html), do_flush=True)

    log.info("Pipeline complete.")
    return scores_json, analysis_json

# {{/docs-fragment pipeline}}

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(pipeline)
    print(run.url)
    run.wait()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/genomic_variant_effect/genomic_variant_effect.py*
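The `score_variants` task records a reference log-likelihood (`ref_ll`), a mutant log-likelihood (`alt_ll`), and a VEP score for each variant; the method note in `generate_summary` defines that score as log P(mutant sequence) minus log P(reference sequence). The sketch below illustrates the same log-likelihood ratio with a toy first-order Markov model standing in for Carbon, so it runs without a GPU or a model download. The transition table and the names `sequence_log_likelihood` and `vep_score` are hypothetical. Carbon conditions each prediction on the full preceding context rather than on a single previous base, and the real task scores sequences with the model itself.

```python
import math

# Hypothetical toy stand-in for Carbon: a first-order Markov model over DNA bases.
# TRANSITIONS[prev][nxt] = P(next base | previous base); each row sums to 1.
TRANSITIONS = {
    "A": {"A": 0.4, "C": 0.2, "G": 0.3, "T": 0.1},
    "C": {"A": 0.25, "C": 0.25, "G": 0.25, "T": 0.25},
    "G": {"A": 0.3, "C": 0.3, "G": 0.2, "T": 0.2},
    "T": {"A": 0.1, "C": 0.3, "G": 0.4, "T": 0.2},
}
START = {"A": 0.25, "C": 0.25, "G": 0.25, "T": 0.25}


def sequence_log_likelihood(seq: str) -> float:
    """Autoregressive log P(seq) = log P(x_0) + sum over t of log P(x_t | x_{t-1})."""
    ll = math.log(START[seq[0]])
    for prev, nxt in zip(seq, seq[1:]):
        ll += math.log(TRANSITIONS[prev][nxt])
    return ll


def vep_score(ref_seq: str, pos: int, ref: str, alt: str) -> dict:
    """Log-likelihood ratio: log P(mutant) - log P(reference), with a 0-indexed pos."""
    if ref_seq[pos] != ref:
        raise ValueError(f"ref base mismatch at {pos}: expected {ref}, found {ref_seq[pos]}")
    alt_seq = ref_seq[:pos] + alt + ref_seq[pos + 1:]
    ref_ll = sequence_log_likelihood(ref_seq)
    alt_ll = sequence_log_likelihood(alt_seq)
    return {"score": alt_ll - ref_ll, "ref_ll": ref_ll, "alt_ll": alt_ll}


result = vep_score("ATGAC", pos=2, ref="G", alt="T")
print(round(result["score"], 4))  # -1.7918
```

Here the substitution replaces the transitions T→G and G→A (probabilities 0.4 and 0.3) with T→T and T→A (0.2 and 0.1), so the mutant is less likely and the score is negative, the direction the pipeline reads as potentially damaging.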

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.4.0",
#    "torch>=2.9.0",
#    "transformers>=4.49.0",
#    "accelerate>=0.34.0",
#    "numpy",
# ]
# ///
```

## Orchestrate the pipeline

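The `pipeline` task runs four stages in order: it loads the variants, scores them with Carbon, measures classification accuracy against the known labels, and generates a summary report. Before each stage it refreshes its own report with a progress tracker, and it returns the scores and the analysis as a pair of JSON strings.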
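The direction accuracy, precision, and recall in the final report come from one fixed rule in `analyze_effects`: a score below -0.05 counts as a pathogenic prediction, and only variants labeled `pathogenic` or `benign` are evaluated. The following sketch restates that rule as a standalone, pure-Python function, which can be handy for trying other thresholds without rerunning the model. The name `classify_and_score` and the sample scores are hypothetical, not pipeline output.

```python
def classify_and_score(variants: list[dict], threshold: float = -0.05) -> dict:
    """Threshold classifier: score < threshold means predicted pathogenic.

    Variants labeled anything other than "pathogenic" or "benign" (for example
    "uncertain") are skipped, mirroring the known-label filter in analyze_effects.
    """
    tp = fp = tn = fn = 0
    for v in variants:
        if v["known_effect"] not in ("pathogenic", "benign"):
            continue
        predicted = v["score"] < threshold  # strict inequality, as in the pipeline
        actual = v["known_effect"] == "pathogenic"
        if predicted and actual:
            tp += 1
        elif predicted:
            fp += 1
        elif actual:
            fn += 1
        else:
            tn += 1
    total = tp + fp + tn + fn
    return {
        "accuracy": (tp + tn) / total if total else 0.0,
        "precision": tp / (tp + fp) if tp + fp else 0.0,
        "recall": tp / (tp + fn) if tp + fn else 0.0,
        # Rows: actual pathogenic / benign; columns: predicted pathogenic / benign.
        "confusion": [[tp, fn], [fp, tn]],
    }


# Hypothetical scores, for illustration only.
sample = [
    {"score": -0.30, "known_effect": "pathogenic"},
    {"score": -0.20, "known_effect": "pathogenic"},
    {"score": 0.01, "known_effect": "pathogenic"},
    {"score": 0.02, "known_effect": "benign"},
    {"score": -0.05, "known_effect": "benign"},  # exactly at the threshold: predicted benign
    {"score": -0.50, "known_effect": "uncertain"},  # ignored
]
metrics = classify_and_score(sample)
print(metrics["confusion"])  # [[2, 1], [0, 2]]
```

The confusion matrix uses the same row and column order as the heatmap in `analyze_effects`. Moving the threshold toward zero labels more variants as pathogenic, which typically raises recall at the cost of precision; the pipeline's own method note already describes -0.05 as a simple heuristic.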

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.4.0",
#    "torch>=2.9.0",
#    "transformers>=4.49.0",
#    "accelerate>=0.34.0",
#    "numpy",
# ]
# main = "pipeline"
# params = ""
# ///
import json
import logging
import math
import os
import tempfile

import flyte
import flyte.io
import flyte.report

# {{docs-fragment env}}
main_img = flyte.Image.from_uv_script(__file__, name="genomic-variant-effect", pre=True)

gpu_env = flyte.TaskEnvironment(
    name="genomic-variant-effect-gpu",
    image=main_img,
    resources=flyte.Resources(cpu=4, memory="24Gi", gpu=1),
)

cpu_env = flyte.TaskEnvironment(
    name="genomic-variant-effect-cpu",
    image=main_img,
    resources=flyte.Resources(cpu=2, memory="6Gi"),
    depends_on=[gpu_env],
)
# {{/docs-fragment env}}

logging.basicConfig(level=logging.WARNING, format="%(message)s", force=True)
log = logging.getLogger(__name__)
log.setLevel(logging.INFO)

# ------------------------------------------------------------------
# Default gene variants — clinically relevant mutations
# ------------------------------------------------------------------
# Each entry: gene name -> { "sequence": reference DNA, "variants": [{ "pos": 0-indexed, "ref": base, "alt": base, "name": "...", "known_effect": "..." }] }
# Sequences are short windows (~120-200bp) around the variant site for tractable inference.

DEFAULT_GENE_VARIANTS = {
    "BRCA2 (Breast Cancer)": {
        "description": "Tumor suppressor critical for DNA repair via homologous recombination. Mutations dramatically increase breast and ovarian cancer risk.",
        "sequence": "ATGGCCTCGAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAGCAG",
        "variants": [
            {"pos": 12, "ref": "A", "alt": "T", "name": "c.37A>T", "known_effect": "pathogenic", "clinical": "Nonsense mutation — truncates protein early"},
            {"pos": 18, "ref": "G", "alt": "A", "name": "c.55G>A", "known_effect": "benign", "clinical": "Synonymous — no amino acid change"},
            {"pos": 30, "ref": "C", "alt": "T", "name": "c.91C>T", "known_effect": "pathogenic", "clinical": "Missense in DNA-binding domain"},
            {"pos": 45, "ref": "G", "alt": "C", "name": "c.136G>C", "known_effect": "uncertain", "clinical": "Variant of uncertain significance (VUS)"},
        ],
    },
    "TP53 (Tumor Suppressor)": {
        "description": "Guardian of the genome. Activates DNA repair, cell cycle arrest, and apoptosis. Mutated in >50% of human cancers.",
        "sequence": "ATGGAGGAGCCGCAGTCAGATCCTAGCGTGAGTTTGCACCCTTCAGAGACAGAAACCACTGGATTGGAGACTACTTCCTGAAACAACGTTCTGTCCCCCTTGCCGTCCCAAGCAATGGATGAT",
        "variants": [
            {"pos": 15, "ref": "C", "alt": "T", "name": "R175H", "known_effect": "pathogenic", "clinical": "Hotspot — gain-of-function, dominant negative. Most common TP53 mutation in cancer"},
            {"pos": 36, "ref": "T", "alt": "C", "name": "P72R", "known_effect": "benign", "clinical": "Common polymorphism — subtle effect on apoptosis efficiency"},
            {"pos": 54, "ref": "C", "alt": "A", "name": "G245S", "known_effect": "pathogenic", "clinical": "Structural mutant — destabilizes DNA-binding loop"},
            {"pos": 72, "ref": "T", "alt": "G", "name": "R248W", "known_effect": "pathogenic", "clinical": "Contact mutant — disrupts DNA binding"},
            {"pos": 90, "ref": "C", "alt": "T", "name": "R273H", "known_effect": "pathogenic", "clinical": "Contact mutant — directly contacts DNA bases"},
        ],
    },
    "CFTR (Cystic Fibrosis)": {
        "description": "Chloride channel protein. Mutations cause cystic fibrosis — the most common lethal genetic disease in people of European descent.",
        "sequence": "ATGCAGAGGTCGCCTCTGGAAAAGGCCAGCGTTGTCTCCAAACTTTTTTTCAGCTGGACCAGACCAATTTTGAGGAAAGGATACAGACAGCGCCTGGAATTGTCAGACATATACCAAATCCCTTC",
        "variants": [
            {"pos": 9, "ref": "G", "alt": "A", "name": "G85E", "known_effect": "pathogenic", "clinical": "Disrupts chloride channel processing"},
            {"pos": 24, "ref": "C", "alt": "T", "name": "R117H", "known_effect": "pathogenic", "clinical": "Reduces channel conductance — milder CF phenotype"},
            {"pos": 48, "ref": "T", "alt": "C", "name": "I148T", "known_effect": "benign", "clinical": "Previously misclassified — now known benign polymorphism"},
            {"pos": 66, "ref": "A", "alt": "G", "name": "R334W", "known_effect": "pathogenic", "clinical": "Gating mutation — channel opens less frequently"},
        ],
    },
    "KRAS (Oncogene)": {
        "description": "GTPase signal switch. KRAS mutations are the most common oncogenic driver — found in ~25% of all human cancers, especially pancreatic, colorectal, and lung.",
        "sequence": "ATGACTGAATATAAACTTGTGGTAGTTGGAGCTGGTGGCGTAGGCAAGAGTGCCTTGACGATACAGCTAATTCAGAATCATTTTGTGGACGAATATGATCCAACAATAGAGGATTCCTACAGGAA",
        "variants": [
            {"pos": 34, "ref": "G", "alt": "T", "name": "G12V", "known_effect": "pathogenic", "clinical": "Locks KRAS in active state — constitutive proliferation signal"},
            {"pos": 35, "ref": "G", "alt": "A", "name": "G12D", "known_effect": "pathogenic", "clinical": "Most common KRAS mutation in pancreatic cancer"},
            {"pos": 37, "ref": "G", "alt": "T", "name": "G13D", "known_effect": "pathogenic", "clinical": "Constitutively active — common in colorectal cancer"},
            {"pos": 60, "ref": "C", "alt": "A", "name": "Q61K", "known_effect": "pathogenic", "clinical": "Impairs GTP hydrolysis — locked ON state"},
        ],
    },
    "HBB (Sickle Cell)": {
        "description": "Beta-globin subunit of hemoglobin. The sickle cell mutation (E6V) is the most well-known single-base disease variant in humans.",
        "sequence": "ATGGTGCATCTGACTCCTGAGGAGAAGTCTGCCGTTACTGCCCTGTGGGGCAAGGTGAACGTGGATGAAGTTGGTGGTGAGGCCCTGGGCAGGCTGCTGGTGGTCTACCCTTGGACCCAGAGG",
        "variants": [
            {"pos": 17, "ref": "A", "alt": "T", "name": "E6V (HbS)", "known_effect": "pathogenic", "clinical": "THE sickle cell mutation — causes hemoglobin polymerization under low O2"},
            {"pos": 19, "ref": "G", "alt": "A", "name": "E6K (HbC)", "known_effect": "pathogenic", "clinical": "Hemoglobin C disease — milder than sickle cell but causes crystal formation"},
            {"pos": 36, "ref": "G", "alt": "A", "name": "E26K", "known_effect": "benign", "clinical": "Hemoglobin E — most common Hb variant worldwide, mild effect"},
            {"pos": 78, "ref": "C", "alt": "T", "name": "Q39X", "known_effect": "pathogenic", "clinical": "Nonsense — causes beta-thalassemia (no functional beta-globin)"},
        ],
    },
}

# DNA base colors (classic genomics color scheme)
BASE_COLORS = {"A": "#2ecc71", "T": "#e74c3c", "G": "#f39c12", "C": "#3498db"}
BASE_COMPLEMENT = {"A": "T", "T": "A", "G": "C", "C": "G"}

# Pathogenicity color scheme
EFFECT_COLORS = {
    "pathogenic": "#dc2626",
    "benign": "#059669",
    "uncertain": "#f59e0b",
}
EFFECT_BADGES = {
    "pathogenic": "badge-danger",
    "benign": "badge-success",
    "uncertain": "badge-warning",
}

# ------------------------------------------------------------------
# Report styling — genomics-themed deep blues and teals
# ------------------------------------------------------------------

REPORT_CSS = """
<style>
  .report { font-family: system-ui, -apple-system, sans-serif; max-width: 960px; margin: 0 auto; color: #1a1a2e; }
  .report h2 { color: #1e3a5f; border-bottom: 2px solid #2563eb; padding-bottom: 8px; margin-top: 24px; }
  .report h3 { color: #1e40af; margin-top: 20px; }
  .report .card { background: #eff6ff; border: 1px solid #bfdbfe; border-radius: 8px; padding: 16px; margin: 12px 0; }
  .report .stat-grid { display: grid; grid-template-columns: repeat(auto-fit, minmax(160px, 1fr)); gap: 12px; margin: 12px 0; }
  .report .stat { background: #fff; border: 1px solid #dbeafe; border-radius: 6px; padding: 12px; text-align: center; }
  .report .stat .value { font-size: 1.5em; font-weight: 700; color: #1e3a5f; }
  .report .stat .label { font-size: 0.85em; color: #6c757d; margin-top: 4px; }
  .report table { border-collapse: collapse; width: 100%; margin: 12px 0; }
  .report th { background: #1e3a5f; color: #fff; padding: 10px 14px; text-align: left; font-weight: 600; }
  .report td { padding: 8px 14px; border-bottom: 1px solid #dbeafe; }
  .report tr:nth-child(even) { background: #eff6ff; }
  .report .badge { display: inline-block; padding: 2px 8px; border-radius: 12px; font-size: 0.8em; font-weight: 600; }
  .report .badge-success { background: #d1fae5; color: #065f46; }
  .report .badge-warning { background: #fef3c7; color: #92400e; }
  .report .badge-danger { background: #fee2e2; color: #991b1b; }
  .report .badge-info { background: #dbeafe; color: #1e40af; }
  .report .chart-container { background: #fff; border: 1px solid #dbeafe; border-radius: 8px; padding: 16px; margin: 16px 0; }
  .report .note { background: #eff6ff; border-left: 4px solid #2563eb; padding: 10px 14px; border-radius: 4px; margin: 12px 0; font-size: 0.9em; }
  .report .gene-card { background: #fff; border: 1px solid #dbeafe; border-radius: 8px; padding: 16px; margin: 12px 0; }
  .report .dna-track { font-family: 'SF Mono', 'Fira Code', monospace; letter-spacing: 1px; }
</style>
"""

def _wrap_report(html: str) -> str:
    return f'{REPORT_CSS}<div class="report">{html}</div>'

# ------------------------------------------------------------------
# SVG chart helpers
# ------------------------------------------------------------------

def _make_bar_chart(
    labels: list[str],
    series: dict[str, list[float]],
    title: str = "",
    colors: list[str] | None = None,
    width: int = 700,
    height: int = 300,
    value_format: str = ".2f",
) -> str:
    """Generate an SVG grouped bar chart."""
    if not labels:
        return ""

    default_colors = ["#2563eb", "#1e3a5f", "#3b82f6", "#60a5fa", "#93c5fd"]
    colors = colors or default_colors

    ml, mr, mt, mb = 70, 20, 40, 80
    cw = width - ml - mr
    ch = height - mt - mb

    all_vals = [v for vals in series.values() for v in vals]
    y_max = max(abs(v) for v in all_vals) if all_vals else 1
    y_min = min(all_vals) if all_vals else 0
    # For VEP scores (negative = more damaging), we need to handle negative values
    if y_min >= 0:
        y_min_plot = 0
        y_max_plot = y_max * 1.15 or 1
    else:
        y_max_plot = max(y_max * 1.15, 0.1)
        y_min_plot = y_min * 1.15

    y_range = y_max_plot - y_min_plot or 1

    n_groups = len(labels)
    n_series = len(series)
    group_width = cw / n_groups
    bar_width = group_width * 0.7 / max(n_series, 1)
    gap = group_width * 0.15

    def sy(v):
        return mt + ch - (v - y_min_plot) / y_range * ch

    svg = [
        f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {height}" '
        f'style="width:100%;max-width:{width}px;height:auto;">',
        f'<rect width="{width}" height="{height}" fill="#fff" rx="6"/>',
    ]

    # Grid lines
    for i in range(6):
        y_tick = y_min_plot + y_range * i / 5
        py = sy(y_tick)
        svg.append(
            f'<line x1="{ml}" y1="{py:.1f}" x2="{ml + cw}" y2="{py:.1f}" '
            f'stroke="#e9ecef" stroke-width="1"/>'
        )
        svg.append(
            f'<text x="{ml - 8}" y="{py + 4:.1f}" text-anchor="end" '
            f'font-size="11" fill="#6c757d">{y_tick:{value_format}}</text>'
        )

    # Zero line
    if y_min_plot < 0 < y_max_plot:
        zy = sy(0)
        svg.append(
            f'<line x1="{ml}" y1="{zy:.1f}" x2="{ml + cw}" y2="{zy:.1f}" '
            f'stroke="#374151" stroke-width="1.5"/>'
        )

    # Bars
    for gi, label in enumerate(labels):
        gx = ml + gi * group_width + gap
        for si, (name, vals) in enumerate(series.items()):
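            # A single series with one color per group (analyze_effects passes
            # per-variant effect colors) colors each bar; otherwise color by series.
            if n_series == 1 and len(colors) == n_groups:
                color = colors[gi]
            else:
                color = colors[si % len(colors)]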
            bx = gx + si * bar_width
            val = vals[gi]
            if val >= 0:
                by = sy(val)
                bh = sy(0) - by if y_min_plot < 0 else mt + ch - by
            else:
                by = sy(0) if y_min_plot < 0 else mt + ch
                bh = sy(val) - by
            svg.append(
                f'<rect x="{bx:.1f}" y="{by:.1f}" width="{bar_width - 1:.1f}" '
                f'height="{max(0, bh):.1f}" fill="{color}" rx="2"/>'
            )
            text_y = by - 4 if val >= 0 else by + bh + 12
            svg.append(
                f'<text x="{bx + bar_width / 2:.1f}" y="{text_y:.1f}" '
                f'text-anchor="middle" font-size="9" fill="#1a1a2e">'
                f'{val:{value_format}}</text>'
            )
        # Rotated group label
        lx = gx + n_series * bar_width / 2
        svg.append(
            f'<text x="{lx:.1f}" y="{mt + ch + 14}" '
            f'text-anchor="end" font-size="10" fill="#6c757d" '
            f'transform="rotate(-35, {lx:.1f}, {mt + ch + 14})">{label}</text>'
        )

    # Title
    if title:
        svg.append(
            f'<text x="{width / 2}" y="22" text-anchor="middle" '
            f'font-size="14" font-weight="600" fill="#1a1a2e">{title}</text>'
        )

    # Legend
    if n_series > 1:
        lx = ml + cw - len(series) * 110
        for si, name in enumerate(series):
            color = colors[si % len(colors)]
            svg.append(
                f'<rect x="{lx + si * 110}" y="{mt + ch + 55}" width="12" '
                f'height="12" rx="2" fill="{color}"/>'
            )
            svg.append(
                f'<text x="{lx + si * 110 + 16}" y="{mt + ch + 66}" font-size="11" '
                f'fill="#1a1a2e">{name}</text>'
            )

    svg.append("</svg>")
    return "\n".join(svg)
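
# Illustrative sketch (hypothetical helper, not called by the pipeline):
# `sy` above is a linear map from the data interval [y_min_plot, y_min_plot + y_range]
# to pixel space, flipped because SVG y grows downward. Factored out as a reusable
# closure it might look like this; the name `_linear_scale` is an assumption
# introduced only for illustration.
def _linear_scale(d0: float, d1: float, r0: float, r1: float):
    """Return f(v) mapping the data interval [d0, d1] onto the pixel interval [r0, r1]."""
    span = (d1 - d0) or 1.0  # guard against a zero-width domain

    def scale(v: float) -> float:
        return r0 + (v - d0) / span * (r1 - r0)

    return scale

# Usage: with mt=40 and ch=180, _linear_scale(-1, 1, mt + ch, mt) sends -1 to 220,
# 1 to 40, and 0 to 130, which matches what sy computes for the same bounds.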

def _make_heatmap(
    matrix: list[list[float]],
    row_labels: list[str],
    col_labels: list[str],
    title: str = "",
    width: int = 700,
    height: int = 500,
    value_format: str = ".2f",
    diverging: bool = False,
) -> str:
    """Generate an SVG heatmap. If diverging=True, uses red-white-blue scale centered at 0."""
    n_rows = len(matrix)
    n_cols = len(matrix[0]) if matrix else 0
    if not n_rows or not n_cols:
        return ""

    show_values = n_rows <= 10 and n_cols <= 12

    flat = [v for row in matrix for v in row]
    v_min = min(flat)
    v_max = max(flat)

    if diverging:
        abs_max = max(abs(v_min), abs(v_max)) or 1

        def get_color(v):
            t = v / abs_max  # -1 to 1
            if t < 0:
                # White to red (negative = damaging)
                r = 255
                g = int(255 * (1 + t))
                b = int(255 * (1 + t))
            else:
                # White to blue (positive = benign)
                r = int(255 * (1 - t))
                g = int(255 * (1 - t))
                b = 255
            return f"rgb({r},{g},{b})"
    else:
        v_range = v_max - v_min or 1

        def get_color(v):
            t = (v - v_min) / v_range
            r = int(255 - t * (255 - 30))
            g = int(255 - t * (255 - 58))
            b = int(255 - t * (255 - 95))
            return f"rgb({r},{g},{b})"

    # Layout
    ml = max(140, max(len(l) for l in row_labels) * 7 + 20) if row_labels else 140
    mr = 20
    mt = 80 if col_labels else 40
    mb = 30
    cw = width - ml - mr
    ch = height - mt - mb

    cell_w = cw / n_cols
    cell_h = ch / n_rows

    svg = [
        f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {height}" '
        f'style="width:100%;max-width:{width}px;height:auto;">',
        f'<rect width="{width}" height="{height}" fill="#fff" rx="6"/>',
    ]

    if title:
        svg.append(
            f'<text x="{width / 2}" y="22" text-anchor="middle" '
            f'font-size="14" font-weight="600" fill="#1a1a2e">{title}</text>'
        )

    # Column labels (rotated)
    for j, label in enumerate(col_labels):
        cx = ml + j * cell_w + cell_w / 2
        svg.append(
            f'<text x="{cx:.1f}" y="{mt - 8}" text-anchor="end" '
            f'font-size="10" fill="#374151" '
            f'transform="rotate(-45, {cx:.1f}, {mt - 8})">{label}</text>'
        )

    # Row labels + cells
    for i, row_label in enumerate(row_labels):
        ry = mt + i * cell_h + cell_h / 2
        svg.append(
            f'<text x="{ml - 8}" y="{ry + 4:.1f}" text-anchor="end" '
            f'font-size="10" fill="#374151">{row_label}</text>'
        )
        for j in range(n_cols):
            val = matrix[i][j]
            color = get_color(val)
            cx = ml + j * cell_w
            cy = mt + i * cell_h
            svg.append(
                f'<rect x="{cx:.1f}" y="{cy:.1f}" width="{cell_w:.1f}" '
                f'height="{cell_h:.1f}" fill="{color}" stroke="#fff" stroke-width="1"/>'
            )
            if show_values:
                if diverging:
                    t = abs(val) / (max(abs(v_min), abs(v_max)) or 1)
                else:
                    t = (val - v_min) / (v_max - v_min or 1)
                text_color = "#fff" if t > 0.55 else "#1a1a2e"
                font_size = min(10, int(cell_w / 4), int(cell_h / 2.5))
                font_size = max(7, font_size)
                svg.append(
                    f'<text x="{cx + cell_w / 2:.1f}" y="{cy + cell_h / 2 + 3:.1f}" '
                    f'text-anchor="middle" font-size="{font_size}" '
                    f'fill="{text_color}">{val:{value_format}}</text>'
                )

    svg.append("</svg>")
    return "\n".join(svg)
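
# Illustrative sketch (hypothetical helper, not called by the pipeline):
# the diverging branch of `_make_heatmap` interpolates linearly from white at 0
# toward pure red for negative values and pure blue for positive values, scaled by
# the largest absolute value in the matrix. The same mapping as a standalone,
# testable function (the name `_diverging_rgb` is an assumption); unlike the
# original it also clamps out-of-range inputs.
def _diverging_rgb(v: float, abs_max: float) -> tuple[int, int, int]:
    """Map v in [-abs_max, abs_max] to an (r, g, b) tuple on a red-white-blue scale."""
    abs_max = abs_max or 1.0
    t = max(-1.0, min(1.0, v / abs_max))  # clamp to [-1, 1]
    if t < 0:
        fade = int(255 * (1 + t))  # negative: white -> red
        return (255, fade, fade)
    fade = int(255 * (1 - t))  # zero or positive: white -> blue
    return (fade, fade, 255)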

def _make_dna_track(
    sequence: str,
    variants: list[dict],
    gene_name: str = "",
    width: int = 900,
) -> str:
    """Render a color-coded DNA sequence track with variant positions highlighted."""
    chars_per_line = 60
    char_w = 11
    line_h = 22
    label_w = 50
    n_lines = (len(sequence) + chars_per_line - 1) // chars_per_line

    variant_positions = {v["pos"] for v in variants}
    # Extra vertical space for the title above and the base-color legend below
    svg_h = n_lines * line_h + 60

    svg = [
        f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {svg_h}" '
        f'style="width:100%;max-width:{width}px;height:auto;font-family:monospace;">',
        f'<rect width="{width}" height="{svg_h}" fill="#fff" rx="6"/>',
    ]

    if gene_name:
        svg.append(
            f'<text x="{width / 2}" y="16" text-anchor="middle" '
            f'font-size="12" font-weight="600" fill="#1e3a5f">{gene_name}</text>'
        )

    y_offset = 28

    for line_idx in range(n_lines):
        start = line_idx * chars_per_line
        end = min(start + chars_per_line, len(sequence))
        chunk = sequence[start:end]
        y = y_offset + line_idx * line_h

        # Position label
        svg.append(
            f'<text x="{label_w - 8}" y="{y}" text-anchor="end" '
            f'font-size="9" fill="#9ca3af">{start + 1}</text>'
        )

        for ci, base in enumerate(chunk):
            abs_pos = start + ci
            x = label_w + ci * char_w
            color = BASE_COLORS.get(base, "#6b7280")
            is_variant = abs_pos in variant_positions

            if is_variant:
                # Red highlight for variant positions
                svg.append(
                    f'<rect x="{x - 1}" y="{y - 13}" width="{char_w + 2}" height="{line_h}" '
                    f'fill="#fee2e2" stroke="#dc2626" stroke-width="1.5" rx="2"/>'
                )
                # Small triangle marker above
                svg.append(
                    f'<polygon points="{x + char_w / 2:.1f},{y - 16} {x + char_w / 2 - 3:.1f},{y - 20} {x + char_w / 2 + 3:.1f},{y - 20}" '
                    f'fill="#dc2626"/>'
                )
            else:
                # Subtle background
                svg.append(
                    f'<rect x="{x}" y="{y - 12}" width="{char_w}" height="{line_h - 2}" '
                    f'fill="{color}" opacity="0.08" rx="1"/>'
                )

            svg.append(
                f'<text x="{x + char_w / 2:.1f}" y="{y}" text-anchor="middle" '
                f'font-size="11" font-weight="{"700" if is_variant else "500"}" '
                f'fill="{color}">{base}</text>'
            )

            # Spacer every 10 bases
            if (abs_pos + 1) % 10 == 0 and ci < len(chunk) - 1:
                svg.append(
                    f'<line x1="{x + char_w + 1}" y1="{y - 11}" '
                    f'x2="{x + char_w + 1}" y2="{y + 4}" '
                    f'stroke="#e5e7eb" stroke-width="0.5"/>'
                )

    # Legend
    ly = svg_h - 12
    lx = label_w
    for base, color in BASE_COLORS.items():
        svg.append(f'<rect x="{lx}" y="{ly - 8}" width="8" height="8" rx="1" fill="{color}"/>')
        svg.append(f'<text x="{lx + 12}" y="{ly}" font-size="9" fill="#6b7280">{base}</text>')
        lx += 30
    lx += 10
    svg.append(f'<rect x="{lx}" y="{ly - 8}" width="8" height="8" rx="1" fill="#fee2e2" stroke="#dc2626" stroke-width="1"/>')
    svg.append(f'<text x="{lx + 12}" y="{ly}" font-size="9" fill="#6b7280">Variant site</text>')

    svg.append("</svg>")
    return "\n".join(svg)
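
# Illustrative sketch (hypothetical helper, not called by the pipeline):
# `_make_dna_track` wraps the sequence into fixed-width lines and labels each line
# with its 1-based start coordinate, while variant positions stay 0-based. The
# wrapping logic in isolation (`_wrap_sequence` is an assumed name):
def _wrap_sequence(sequence: str, width: int = 60) -> list[tuple[int, str]]:
    """Split `sequence` into (1-based start, chunk) pairs of at most `width` bases."""
    return [
        (start + 1, sequence[start:start + width])
        for start in range(0, len(sequence), width)
    ]

# Usage: a 0-based variant at position p is drawn on line p // width, column p % width.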

def _make_lollipop_plot(
    variants: list[dict],
    seq_length: int,
    title: str = "",
    width: int = 700,
    height: int = 200,
) -> str:
    """Generate a lollipop plot showing variant positions along a gene with effect scores."""
    if not variants:
        return ""

    ml, mr, mt, mb = 50, 30, 40, 40
    cw = width - ml - mr
    ch = height - mt - mb

    svg = [
        f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {height}" '
        f'style="width:100%;max-width:{width}px;height:auto;">',
        f'<rect width="{width}" height="{height}" fill="#fff" rx="6"/>',
    ]

    if title:
        svg.append(
            f'<text x="{width / 2}" y="22" text-anchor="middle" '
            f'font-size="13" font-weight="600" fill="#1a1a2e">{title}</text>'
        )

    # Gene track (horizontal bar)
    track_y = mt + ch * 0.7
    svg.append(
        f'<rect x="{ml}" y="{track_y - 4}" width="{cw}" height="8" '
        f'fill="#dbeafe" rx="4" stroke="#93c5fd" stroke-width="1"/>'
    )

    # Position labels at ends
    svg.append(
        f'<text x="{ml}" y="{track_y + 20}" text-anchor="middle" '
        f'font-size="9" fill="#9ca3af">1</text>'
    )
    svg.append(
        f'<text x="{ml + cw}" y="{track_y + 20}" text-anchor="middle" '
        f'font-size="9" fill="#9ca3af">{seq_length}</text>'
    )

    # Lollipops
    scores = [v.get("score", 0) for v in variants]
    # Fall back to 1 so an all-zero score list does not divide by zero below
    max_score = max((abs(s) for s in scores), default=0) or 1

    for v in variants:
        pos = v["pos"]
        score = v.get("score", 0)
        effect = v.get("known_effect", "uncertain")
        color = EFFECT_COLORS.get(effect, "#6b7280")

        x = ml + (pos / max(seq_length - 1, 1)) * cw
        stem_h = max(20, abs(score) / max_score * (ch * 0.6))
        circle_y = track_y - stem_h - 6

        # Stem
        svg.append(
            f'<line x1="{x:.1f}" y1="{track_y - 4}" x2="{x:.1f}" y2="{circle_y + 5:.1f}" '
            f'stroke="{color}" stroke-width="2"/>'
        )
        # Circle
        svg.append(
            f'<circle cx="{x:.1f}" cy="{circle_y:.1f}" r="5" fill="{color}" '
            f'stroke="#fff" stroke-width="1.5"/>'
        )
        # Label
        svg.append(
            f'<text x="{x:.1f}" y="{circle_y - 9:.1f}" text-anchor="middle" '
            f'font-size="8" fill="#374151" font-weight="600">{v.get("name", "")}</text>'
        )

    # Legend
    lx = ml
    for effect, color in EFFECT_COLORS.items():
        svg.append(f'<circle cx="{lx + 4}" cy="{height - 10}" r="4" fill="{color}"/>')
        svg.append(
            f'<text x="{lx + 12}" y="{height - 6}" font-size="9" fill="#374151">'
            f'{effect.title()}</text>'
        )
        lx += len(effect) * 7 + 24

    svg.append("</svg>")
    return "\n".join(svg)
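
# Illustrative sketch (hypothetical helper, not called by the pipeline):
# lollipop stems scale with |score| relative to the largest |score| in the gene,
# with a 20 px floor so small effects stay visible. Guarding the denominator keeps
# an all-zero score list from dividing by zero. `_stem_heights` is an assumed name.
def _stem_heights(
    scores: list[float], max_height: float, min_height: float = 20.0
) -> list[float]:
    """Return one stem height per score, proportional to |score| with a floor."""
    max_abs = max((abs(s) for s in scores), default=0.0) or 1.0
    return [max(min_height, abs(s) / max_abs * max_height) for s in scores]

# Usage: with the default height=200, ch is 120, so the tallest stem is ch * 0.6 = 72 px.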

def _make_score_comparison_chart(
    gene_results: dict,
    width: int = 800,
    height: int = 350,
) -> str:
    """Generate a dot plot comparing VEP scores across all genes, colored by known effect."""
    all_variants = []
    for gene_name, gene_data in gene_results.items():
        for v in gene_data["variants"]:
            all_variants.append({
                "gene": gene_name.split("(")[0].strip(),
                "name": v["name"],
                "score": v.get("score", 0),
                "known_effect": v["known_effect"],
            })

    if not all_variants:
        return ""

    ml, mr, mt, mb = 120, 30, 40, 30
    cw = width - ml - mr
    ch = height - mt - mb

    scores = [v["score"] for v in all_variants]
    s_min, s_max = min(scores), max(scores)
    s_pad = (s_max - s_min) * 0.1 or 0.5
    s_min -= s_pad
    s_max += s_pad
    s_range = s_max - s_min or 1

    def sx(s):
        return ml + (s - s_min) / s_range * cw

    svg = [
        f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {height}" '
        f'style="width:100%;max-width:{width}px;height:auto;">',
        f'<rect width="{width}" height="{height}" fill="#fff" rx="6"/>',
        f'<text x="{width / 2}" y="22" text-anchor="middle" '
        f'font-size="14" font-weight="600" fill="#1a1a2e">'
        f'Variant Effect Scores — All Genes</text>',
    ]

    # Grid
    for i in range(6):
        gx = s_min + s_range * i / 5
        px = sx(gx)
        svg.append(
            f'<line x1="{px:.1f}" y1="{mt}" x2="{px:.1f}" y2="{mt + ch}" '
            f'stroke="#f3f4f6" stroke-width="1"/>'
        )
        svg.append(
            f'<text x="{px:.1f}" y="{mt + ch + 16}" text-anchor="middle" '
            f'font-size="10" fill="#9ca3af">{gx:.2f}</text>'
        )

    # Zero line
    if s_min < 0 < s_max:
        zx = sx(0)
        svg.append(
            f'<line x1="{zx:.1f}" y1="{mt}" x2="{zx:.1f}" y2="{mt + ch}" '
            f'stroke="#374151" stroke-width="1" stroke-dasharray="4,3"/>'
        )

    # Rows per variant
    row_h = ch / max(len(all_variants), 1)
    for i, v in enumerate(all_variants):
        y = mt + i * row_h + row_h / 2
        color = EFFECT_COLORS.get(v["known_effect"], "#6b7280")
        px = sx(v["score"])

        # Label
        label = f'{v["gene"]} {v["name"]}'
        svg.append(
            f'<text x="{ml - 8}" y="{y + 3:.1f}" text-anchor="end" '
            f'font-size="9" fill="#374151">{label}</text>'
        )
        # Connector line
        svg.append(
            f'<line x1="{ml}" y1="{y:.1f}" x2="{px:.1f}" y2="{y:.1f}" '
            f'stroke="{color}" stroke-width="1" opacity="0.3"/>'
        )
        # Dot
        svg.append(
            f'<circle cx="{px:.1f}" cy="{y:.1f}" r="5" fill="{color}" '
            f'stroke="#fff" stroke-width="1"/>'
        )

    svg.append("</svg>")
    return "\n".join(svg)
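
# Illustrative sketch (hypothetical helper, not called by the pipeline):
# the dot plot pads the observed score range by 10% on each side (or by 0.5 when
# every score is identical) so dots never sit on the plot edge. The same
# computation in isolation (`_padded_domain` is an assumed name):
def _padded_domain(
    values: list[float], frac: float = 0.1, fallback: float = 0.5
) -> tuple[float, float]:
    """Return (lo, hi) covering `values` plus proportional padding."""
    lo, hi = min(values), max(values)
    pad = (hi - lo) * frac or fallback
    return lo - pad, hi + pad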

# ------------------------------------------------------------------
# Task 1: Load and validate gene variants
# ------------------------------------------------------------------

@cpu_env.task(cache="auto")
async def load_variants(
    variants_json: str = "",
) -> flyte.io.Dir:
    """Load gene variant definitions, validate sequences, and save to a temp directory."""
    if variants_json:
        genes = json.loads(variants_json)
    else:
        # Deep-copy so validation below never mutates the module-level default
        genes = json.loads(json.dumps(DEFAULT_GENE_VARIANTS))

    # Validate
    valid_bases = set("ATGC")
    for gene_name, gene_data in genes.items():
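        seq = gene_data["sequence"].upper()
        invalid = set(seq) - valid_bases
        if invalid:
            # Dropping bases shifts every downstream coordinate, so variants after
            # the first removed base may no longer match their reference base.
            log.warning(f"{gene_name}: invalid bases {invalid} — removing them")
            seq = "".join(b for b in seq if b in valid_bases)
        # Always store the normalized (uppercase, filtered) sequence for later tasks
        gene_data["sequence"] = seq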

        for v in gene_data["variants"]:
            pos = v["pos"]
            if pos < 0 or pos >= len(seq):
                log.warning(f"{gene_name} variant {v['name']}: position {pos} out of range [0, {len(seq)})")
            elif seq[pos] != v["ref"]:
                log.warning(f"{gene_name} variant {v['name']}: expected ref={v['ref']} at pos {pos}, found {seq[pos]}")

    total_variants = sum(len(g["variants"]) for g in genes.values())
    log.info(f"Loaded {len(genes)} genes with {total_variants} variants")

    out_dir = tempfile.mkdtemp(prefix="genomic_vep_")
    with open(os.path.join(out_dir, "genes.json"), "w") as f:
        json.dump(genes, f)

    return await flyte.io.Dir.from_local(out_dir)
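
# Illustrative sketch (hypothetical helper, not called by the pipeline):
# `load_variants` only warns about out-of-range positions or reference mismatches,
# and `score_variants` later builds each mutant by splicing the alt base into the
# reference. A stricter version that refuses invalid single-nucleotide
# substitutions might look like this (`_apply_snv` is an assumed name).
def _apply_snv(seq: str, pos: int, ref: str, alt: str) -> str:
    """Return `seq` with the base at 0-based `pos` changed from `ref` to `alt`."""
    if not 0 <= pos < len(seq):
        raise ValueError(f"position {pos} out of range [0, {len(seq)})")
    if seq[pos] != ref:
        raise ValueError(f"expected ref {ref!r} at {pos}, found {seq[pos]!r}")
    return seq[:pos] + alt + seq[pos + 1:]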

# ------------------------------------------------------------------
# Task 2: Run Carbon model for variant effect scoring
# ------------------------------------------------------------------

@gpu_env.task(report=True)
async def score_variants(
    variants_dir: flyte.io.Dir,
    model_name: str = "HuggingFaceBio/Carbon-3B",
) -> str:
    """Score each variant using Carbon's log-likelihood ratio.

    For each variant, we compute:
        score = log P(alt_sequence) - log P(ref_sequence)

    A negative score means the model considers the variant less likely than
    the reference — suggestive of a damaging/pathogenic effect. A score near
    zero means the model sees little difference (likely benign).
    """
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    log.info(f"Loading Carbon model: {model_name}")

    # Load model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cpu":
        log.warning("Running on CPU — inference will be slow. GPU recommended for production.")

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        dtype=torch.bfloat16 if device == "cuda" else torch.float32,
    ).to(device)
    model.eval()

    # Load variants
    variants_path = await variants_dir.download()
    with open(os.path.join(variants_path, "genes.json")) as f:
        genes = json.load(f)

    results = {}
    total_variants = sum(len(g["variants"]) for g in genes.values())
    scored = 0

    progress_html = """
    <h2>Carbon Variant Effect Scoring</h2>
    <div class="card">
        <b>Model:</b> {model}<br>
        <b>Device:</b> {device}<br>
        <b>Progress:</b> {scored}/{total} variants scored
    </div>
    """

    for gene_name, gene_data in genes.items():
        ref_seq = gene_data["sequence"]
        gene_results = {
            "description": gene_data.get("description", ""),
            "sequence": ref_seq,
            "variants": [],
        }

        # Score reference sequence
        ref_prompt = f"<dna>{ref_seq}"
        ref_inputs = tokenizer(ref_prompt, return_tensors="pt", add_special_tokens=False).to(device)

        with torch.no_grad():
            ref_output = model(**ref_inputs, labels=ref_inputs["input_ids"])
            ref_loss = ref_output.loss.item()
            # Hugging Face causal LMs shift labels internally, so the mean loss covers
            # n - 1 predicted tokens; scale by that count to recover the summed log-likelihood.
            ref_ll = -ref_loss * (ref_inputs["input_ids"].shape[1] - 1)

        for variant in gene_data["variants"]:
            scored += 1
            await flyte.report.replace.aio(
                _wrap_report(progress_html.format(
                    model=model_name, device=device, scored=scored, total=total_variants
                )),
                do_flush=True,
            )

            # Create mutant sequence
            pos = variant["pos"]
            alt_seq = ref_seq[:pos] + variant["alt"] + ref_seq[pos + 1:]

            # Score mutant
            alt_prompt = f"<dna>{alt_seq}"
            alt_inputs = tokenizer(alt_prompt, return_tensors="pt", add_special_tokens=False).to(device)

            with torch.no_grad():
                alt_output = model(**alt_inputs, labels=alt_inputs["input_ids"])
                alt_loss = alt_output.loss.item()
                alt_ll = -alt_loss * (alt_inputs["input_ids"].shape[1] - 1)

            # VEP score: negative = model prefers ref (suggestive of a damaging effect);
            # near zero or positive = little or no penalty for the alt base (likely tolerated)
            vep_score = alt_ll - ref_ll

            gene_results["variants"].append({
                **variant,
                "score": round(vep_score, 4),
                "ref_ll": round(ref_ll, 4),
                "alt_ll": round(alt_ll, 4),
            })

            log.info(
                f"  {gene_name} | {variant['name']}: score={vep_score:.4f} "
                f"(known: {variant['known_effect']})"
            )

        results[gene_name] = gene_results

    # Generate scoring report
    html_parts = [
        "<h2>Carbon Variant Effect Scoring</h2>",
        '<div class="stat-grid">',
        f'<div class="stat"><div class="value">{len(genes)}</div><div class="label">Genes</div></div>',
        f'<div class="stat"><div class="value">{total_variants}</div><div class="label">Variants Scored</div></div>',
        f'<div class="stat"><div class="value">{model_name.split("/")[-1]}</div><div class="label">Model</div></div>',
        f'<div class="stat"><div class="value">{device.upper()}</div><div class="label">Device</div></div>',
        "</div>",
    ]

    # Per-gene tables
    for gene_name, gene_data in results.items():
        html_parts.append(f'<h3>{gene_name}</h3>')
        html_parts.append(f'<div class="note">{gene_data["description"]}</div>')
        html_parts.append("<table><tr><th>Variant</th><th>Ref</th><th>Alt</th><th>VEP Score</th><th>Known Effect</th><th>Clinical</th></tr>")

        for v in gene_data["variants"]:
            badge = EFFECT_BADGES.get(v["known_effect"], "badge-info")
            # Same -0.05 cutoff that analyze_effects uses for "predicted pathogenic"
            direction = "damaging" if v["score"] < -0.05 else "neutral" if abs(v["score"]) <= 0.05 else "tolerated"
            html_parts.append(
                f'<tr>'
                f'<td><b>{v["name"]}</b></td>'
                f'<td style="color:{BASE_COLORS.get(v["ref"], "#333")};font-weight:700">{v["ref"]}</td>'
                f'<td style="color:{BASE_COLORS.get(v["alt"], "#333")};font-weight:700">{v["alt"]}</td>'
                f'<td><b>{v["score"]:.4f}</b> <span style="font-size:0.8em;color:#6c757d">({direction})</span></td>'
                f'<td><span class="badge {badge}">{v["known_effect"]}</span></td>'
                f'<td style="font-size:0.85em">{v.get("clinical", "")}</td>'
                f'</tr>'
            )
        html_parts.append("</table>")

    await flyte.report.replace.aio(_wrap_report("\n".join(html_parts)), do_flush=True)

    return json.dumps(results)
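
# Illustrative sketch (hypothetical helper, not called by the pipeline):
# the quantity the ref_ll / alt_ll lines recover from the model's mean loss. In a
# causal LM the logits at position i predict token i + 1, so the summed
# log-likelihood of a tokenized sequence is the sum of
# log_softmax(logits[i])[token[i + 1]] for i = 0 .. n - 2, which is n - 1 terms.
# This pure-Python version takes per-position logit lists (an assumed input format)
# instead of tensors so it runs without torch.
def _sequence_log_likelihood(logits: list[list[float]], token_ids: list[int]) -> float:
    """Sum of next-token log-probabilities for `token_ids` under `logits`."""
    import math

    total = 0.0
    for i in range(len(token_ids) - 1):
        row = logits[i]
        m = max(row)  # subtract the max for a numerically stable log-sum-exp
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        total += row[token_ids[i + 1]] - log_z
    return total

# The VEP score is then alt_ll - ref_ll, with both terms computed the same way on
# the mutant and reference sequences.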

# ------------------------------------------------------------------
# Task 3: Analyze and visualize variant effects
# ------------------------------------------------------------------

@cpu_env.task(report=True)
async def analyze_effects(
    scores_json: str,
    variants_dir: flyte.io.Dir,
) -> str:
    """Analyze VEP scores: classification accuracy, gene-level summaries, and rich visualizations."""
    results = json.loads(scores_json)

    variants_path = await variants_dir.download()
    with open(os.path.join(variants_path, "genes.json")) as f:
        genes = json.load(f)

    html_parts = ["<h2>Variant Effect Analysis</h2>"]

    # ------------------------------------------------------------------
    # Overall accuracy: does the model's score direction match known labels?
    # ------------------------------------------------------------------
    all_variants = []
    correct = 0
    total_known = 0
    true_pos = 0
    false_pos = 0
    true_neg = 0
    false_neg = 0

    for gene_name, gene_data in results.items():
        for v in gene_data["variants"]:
            all_variants.append({**v, "gene": gene_name})
            if v["known_effect"] in ("pathogenic", "benign"):
                total_known += 1
                predicted_pathogenic = v["score"] < -0.05
                actual_pathogenic = v["known_effect"] == "pathogenic"
                if predicted_pathogenic == actual_pathogenic:
                    correct += 1
                if predicted_pathogenic and actual_pathogenic:
                    true_pos += 1
                elif predicted_pathogenic and not actual_pathogenic:
                    false_pos += 1
                elif not predicted_pathogenic and actual_pathogenic:
                    false_neg += 1
                else:
                    true_neg += 1

    accuracy = correct / total_known if total_known else 0
    precision = true_pos / (true_pos + false_pos) if (true_pos + false_pos) else 0
    recall = true_pos / (true_pos + false_neg) if (true_pos + false_neg) else 0

    html_parts.append('<div class="stat-grid">')
    html_parts.append(f'<div class="stat"><div class="value">{accuracy:.0%}</div><div class="label">Direction Accuracy</div></div>')
    html_parts.append(f'<div class="stat"><div class="value">{precision:.0%}</div><div class="label">Precision (Pathogenic)</div></div>')
    html_parts.append(f'<div class="stat"><div class="value">{recall:.0%}</div><div class="label">Recall (Pathogenic)</div></div>')
    html_parts.append(f'<div class="stat"><div class="value">{len(all_variants)}</div><div class="label">Total Variants</div></div>')
    html_parts.append("</div>")

    html_parts.append(
        '<div class="note">'
        "<b>How to read VEP scores:</b> Negative scores mean Carbon considers the variant "
        "less likely than the reference sequence — suggestive of a damaging effect. Scores near "
        "zero indicate the model sees little difference (likely benign). Larger magnitudes reflect "
        "a bigger likelihood shift, but the scores are not calibrated probabilities."
        "</div>"
    )

    # ------------------------------------------------------------------
    # Cross-gene score comparison dot plot
    # ------------------------------------------------------------------
    html_parts.append('<div class="chart-container">')
    html_parts.append(_make_score_comparison_chart(results))
    html_parts.append("</div>")

    # ------------------------------------------------------------------
    # Per-gene visualizations
    # ------------------------------------------------------------------
    for gene_name, gene_data in results.items():
        short_name = gene_name.split("(")[0].strip()
        html_parts.append(f'<h3>{gene_name}</h3>')
        html_parts.append(f'<div class="note">{gene_data["description"]}</div>')

        # DNA track with variant positions highlighted
        html_parts.append('<div class="chart-container">')
        html_parts.append(_make_dna_track(
            gene_data["sequence"],
            gene_data["variants"],
            gene_name=f"{short_name} Reference Sequence",
        ))
        html_parts.append("</div>")

        # Lollipop plot
        html_parts.append('<div class="chart-container">')
        html_parts.append(_make_lollipop_plot(
            gene_data["variants"],
            len(gene_data["sequence"]),
            title=f"{short_name} — Variant Positions & Effect Scores",
        ))
        html_parts.append("</div>")

        # Score bar chart for this gene
        variant_names = [v["name"] for v in gene_data["variants"]]
        variant_scores = [v["score"] for v in gene_data["variants"]]
        html_parts.append('<div class="chart-container">')
        html_parts.append(_make_bar_chart(
            variant_names,
            {"VEP Score": variant_scores},
            title=f"{short_name} — Log-Likelihood Ratio Scores",
            colors=[EFFECT_COLORS.get(v["known_effect"], "#6b7280") for v in gene_data["variants"]],
            value_format=".3f",
        ))
        html_parts.append("</div>")

        # Variant detail cards
        for v in gene_data["variants"]:
            badge = EFFECT_BADGES.get(v["known_effect"], "badge-info")
            score_color = "#dc2626" if v["score"] < -0.05 else "#059669" if v["score"] > 0.05 else "#f59e0b"
            html_parts.append(
                f'<div class="gene-card">'
                f'<div style="display:flex;justify-content:space-between;align-items:center;margin-bottom:8px;">'
                f'<b style="font-size:1.1em">{v["name"]}</b>'
                f'<span class="badge {badge}">{v["known_effect"]}</span>'
                f'</div>'
                f'<div style="display:grid;grid-template-columns:1fr 1fr 1fr;gap:8px;">'
                f'<div><span style="color:#6c757d;font-size:0.85em">Ref base:</span> '
                f'<b style="color:{BASE_COLORS.get(v["ref"], "#333")}">{v["ref"]}</b></div>'
                f'<div><span style="color:#6c757d;font-size:0.85em">Alt base:</span> '
                f'<b style="color:{BASE_COLORS.get(v["alt"], "#333")}">{v["alt"]}</b></div>'
                f'<div><span style="color:#6c757d;font-size:0.85em">VEP Score:</span> '
                f'<b style="color:{score_color}">{v["score"]:.4f}</b></div>'
                f'</div>'
                f'<div style="margin-top:8px;font-size:0.9em;color:#374151">{v.get("clinical", "")}</div>'
                f'</div>'
            )

    # ------------------------------------------------------------------
    # Confusion matrix as heatmap
    # ------------------------------------------------------------------
    html_parts.append("<h3>Classification Performance</h3>")
    html_parts.append(
        '<div class="note">'
        "Variants are classified with a simple threshold (score &lt; -0.05 = predicted pathogenic). "
        "This is zero-shot — no training on these specific variants."
        "</div>"
    )

    conf_matrix = [[true_pos, false_neg], [false_pos, true_neg]]
    html_parts.append('<div class="chart-container">')
    html_parts.append(_make_heatmap(
        conf_matrix,
        ["Actual Pathogenic", "Actual Benign"],
        ["Predicted Pathogenic", "Predicted Benign"],
        title="Confusion Matrix (Known Variants Only)",
        value_format=".0f",
        width=400,
        height=300,
    ))
    html_parts.append("</div>")

    # ------------------------------------------------------------------
    # Score distribution by known effect
    # ------------------------------------------------------------------
    html_parts.append("<h3>Score Distribution by Known Effect</h3>")

    pathogenic_scores = [v["score"] for v in all_variants if v["known_effect"] == "pathogenic"]
    benign_scores = [v["score"] for v in all_variants if v["known_effect"] == "benign"]
    uncertain_scores = [v["score"] for v in all_variants if v["known_effect"] == "uncertain"]

    stats_html = '<div style="display:grid;grid-template-columns:1fr 1fr 1fr;gap:12px;">'
    for label, scores, color in [
        ("Pathogenic", pathogenic_scores, "#dc2626"),
        ("Benign", benign_scores, "#059669"),
        ("Uncertain", uncertain_scores, "#f59e0b"),
    ]:
        if scores:
            mean_s = sum(scores) / len(scores)
            min_s = min(scores)
            max_s = max(scores)
            stats_html += (
                f'<div class="gene-card" style="border-left:4px solid {color}">'
                f'<b style="color:{color}">{label}</b> (n={len(scores)})<br>'
                f'Mean: {mean_s:.4f}<br>'
                f'Range: [{min_s:.4f}, {max_s:.4f}]'
                f'</div>'
            )
    stats_html += "</div>"
    html_parts.append(stats_html)

    # Summary
    analysis = {
        "total_variants": len(all_variants),
        "total_known": total_known,
        "accuracy": round(accuracy, 4),
        "precision": round(precision, 4),
        "recall": round(recall, 4),
        "true_pos": true_pos,
        "false_pos": false_pos,
        "true_neg": true_neg,
        "false_neg": false_neg,
        "pathogenic_mean_score": round(sum(pathogenic_scores) / len(pathogenic_scores), 4) if pathogenic_scores else None,
        "benign_mean_score": round(sum(benign_scores) / len(benign_scores), 4) if benign_scores else None,
    }

    await flyte.report.replace.aio(_wrap_report("\n".join(html_parts)), do_flush=True)
    return json.dumps(analysis)
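
# Illustrative sketch (hypothetical helper, not called by the pipeline):
# the accuracy, precision, and recall in `analyze_effects` come from a 2x2
# confusion matrix built by thresholding scores (score < threshold means
# "predicted pathogenic") on variants labelled pathogenic or benign. The same
# bookkeeping as a pure function (`_threshold_metrics` is an assumed name):
def _threshold_metrics(
    scores: list[float], labels: list[str], threshold: float = -0.05
) -> dict[str, float]:
    """Return accuracy/precision/recall for pathogenic-vs-benign labels; other labels are skipped."""
    tp = fp = tn = fn = 0
    for score, label in zip(scores, labels):
        if label not in ("pathogenic", "benign"):
            continue  # "uncertain" variants have no ground truth
        predicted = score < threshold
        actual = label == "pathogenic"
        if predicted and actual:
            tp += 1
        elif predicted:
            fp += 1
        elif actual:
            fn += 1
        else:
            tn += 1
    known = tp + fp + tn + fn
    return {
        "accuracy": (tp + tn) / known if known else 0.0,
        "precision": tp / (tp + fp) if tp + fp else 0.0,
        "recall": tp / (tp + fn) if tp + fn else 0.0,
    }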

# ------------------------------------------------------------------
# Task 4: Generate comprehensive summary report
# ------------------------------------------------------------------

@cpu_env.task(report=True)
async def generate_summary(
    scores_json: str,
    analysis_json: str,
) -> str:
    """Generate the final summary report combining all results."""
    results = json.loads(scores_json)
    analysis = json.loads(analysis_json)

    html_parts = [
        "<h2>Genomic Variant Effect Prediction — Summary</h2>",
        '<div class="note">'
        "This pipeline uses <b>HuggingFace Carbon</b>, an autoregressive genomic foundation model "
        "trained on 1 trillion tokens of DNA sequence, to perform <b>zero-shot variant effect "
        "prediction</b>. No fine-tuning or labeled training data was used — the model scores "
        "variants purely based on its learned understanding of DNA sequence grammar."
        "</div>",
    ]

    # Key metrics
    html_parts.append('<div class="stat-grid">')
    html_parts.append(f'<div class="stat"><div class="value">{len(results)}</div><div class="label">Genes Analyzed</div></div>')
    html_parts.append(f'<div class="stat"><div class="value">{analysis["total_variants"]}</div><div class="label">Variants Scored</div></div>')
    html_parts.append(f'<div class="stat"><div class="value">{analysis["accuracy"]:.0%}</div><div class="label">Direction Accuracy</div></div>')
    html_parts.append(f'<div class="stat"><div class="value">{analysis["precision"]:.0%}</div><div class="label">Precision</div></div>')
    html_parts.append(f'<div class="stat"><div class="value">{analysis["recall"]:.0%}</div><div class="label">Recall</div></div>')
    html_parts.append("</div>")

    # Gene summary table
    html_parts.append("<h3>Per-Gene Summary</h3>")
    html_parts.append(
        "<table><tr><th>Gene</th><th>Variants</th><th>Mean Score</th>"
        "<th>Pathogenic</th><th>Benign</th><th>Uncertain</th></tr>"
    )

    for gene_name, gene_data in results.items():
        variants = gene_data["variants"]
        scores = [v["score"] for v in variants]
        mean_score = sum(scores) / len(scores) if scores else 0
        n_path = sum(1 for v in variants if v["known_effect"] == "pathogenic")
        n_benign = sum(1 for v in variants if v["known_effect"] == "benign")
        n_unc = sum(1 for v in variants if v["known_effect"] == "uncertain")
        short = gene_name.split("(")[0].strip()

        html_parts.append(
            f"<tr><td><b>{short}</b></td><td>{len(variants)}</td>"
            f"<td>{mean_score:.4f}</td>"
            f'<td><span class="badge badge-danger">{n_path}</span></td>'
            f'<td><span class="badge badge-success">{n_benign}</span></td>'
            f'<td><span class="badge badge-warning">{n_unc}</span></td></tr>'
        )
    html_parts.append("</table>")

    # Cross-gene heatmap: gene x metric
    gene_names = [g.split("(")[0].strip() for g in results.keys()]
    metrics = ["Mean Score", "Min Score", "Max Score", "# Variants"]
    matrix = []
    for gene_data in results.values():
        scores = [v["score"] for v in gene_data["variants"]]
        matrix.append([
            sum(scores) / len(scores) if scores else 0,
            min(scores) if scores else 0,
            max(scores) if scores else 0,
            len(scores),
        ])

    html_parts.append("<h3>Gene-Level Metrics</h3>")
    html_parts.append('<div class="chart-container">')
    html_parts.append(_make_heatmap(
        matrix,
        gene_names,
        metrics,
        title="Gene-Level VEP Score Summary",
        value_format=".2f",
        width=600,
        height=350,
    ))
    html_parts.append("</div>")

    # All variants ranked by score (most damaging first)
    html_parts.append("<h3>All Variants Ranked by Impact</h3>")
    all_vars_sorted = []
    for gene_name, gene_data in results.items():
        for v in gene_data["variants"]:
            all_vars_sorted.append({**v, "gene": gene_name.split("(")[0].strip()})
    all_vars_sorted.sort(key=lambda x: x["score"])

    html_parts.append(
        "<table><tr><th>#</th><th>Gene</th><th>Variant</th><th>Score</th>"
        "<th>Known</th><th>Clinical Significance</th></tr>"
    )
    for i, v in enumerate(all_vars_sorted):
        badge = EFFECT_BADGES.get(v["known_effect"], "badge-info")
        score_color = "#dc2626" if v["score"] < -0.1 else "#059669" if v["score"] > 0.05 else "#f59e0b"
        html_parts.append(
            f'<tr><td>{i + 1}</td><td><b>{v["gene"]}</b></td><td>{v["name"]}</td>'
            f'<td style="color:{score_color};font-weight:700">{v["score"]:.4f}</td>'
            f'<td><span class="badge {badge}">{v["known_effect"]}</span></td>'
            f'<td style="font-size:0.85em">{v.get("clinical", "")}</td></tr>'
        )
    html_parts.append("</table>")

    # Method note
    html_parts.append(
        '<div class="note">'
        "<b>Method:</b> Zero-shot variant effect prediction using log-likelihood ratio scoring. "
        "For each variant, we compute score = log P(mutant sequence | Carbon) - log P(reference sequence | Carbon). "
        "Negative scores indicate the model considers the mutant less probable than the reference, "
        "which correlates with pathogenicity. This approach requires no fine-tuning and generalizes "
        "across genes and variant types.<br><br>"
        "<b>Limitations:</b> These are short sequence windows — real clinical VEP would use longer "
        "genomic context (Carbon supports up to 786kbp). The threshold for pathogenicity classification "
        "(-0.05) is a simple heuristic; clinical use requires calibrated thresholds per gene."
        "</div>"
    )

    await flyte.report.replace.aio(_wrap_report("\n".join(html_parts)), do_flush=True)
    return json.dumps({"status": "complete", "analysis": analysis})

# ------------------------------------------------------------------
# Pipeline orchestrator
# ------------------------------------------------------------------

# {{docs-fragment pipeline}}
@cpu_env.task(report=True)
async def pipeline(
    variants_json: str = "",
    model_name: str = "HuggingFaceBio/Carbon-3B",
) -> tuple[str, str]:
    """
    End-to-end genomic variant effect prediction pipeline.

    Returns (scores JSON, analysis JSON).

    1. Load and validate gene variants
    2. Score variants with Carbon (log-likelihood ratio)
    3. Analyze effects — accuracy, visualizations, classification
    4. Generate comprehensive summary report
    """
    log.info("Starting genomic variant effect prediction pipeline...")

    def _pipeline_progress(step: int, label: str) -> str:
        steps = [
            "Load Variants",
            "Carbon Scoring",
            "Analyze Effects",
            "Generate Summary",
        ]
        dots = ""
        for i, s in enumerate(steps):
            if i + 1 < step:
                icon = '<span style="color:#2563eb;">&#10003;</span>'
            elif i + 1 == step:
                icon = '<span style="color:#2563eb;">&#9679;</span>'
            else:
                icon = '<span style="color:#adb5bd;">&#9675;</span>'
            dots += f"<span style='margin:0 8px;'>{icon} {s}</span>"
        return f"""
        <h2>Genomic Variant Effect Prediction</h2>
        <div class="card" style="text-align:center;">{dots}</div>
        <p>{label}</p>
        """

    # Stage 1: Load variants
    await flyte.report.replace.aio(
        _wrap_report(_pipeline_progress(1, "Loading and validating gene variants...")),
        do_flush=True,
    )
    var_dir = await load_variants(variants_json=variants_json)

    # Stage 2: Score with Carbon
    await flyte.report.replace.aio(
        _wrap_report(_pipeline_progress(2, "Running Carbon model for variant effect scoring...")),
        do_flush=True,
    )
    scores_json = await score_variants(variants_dir=var_dir, model_name=model_name)

    # Stage 3: Analyze effects
    await flyte.report.replace.aio(
        _wrap_report(_pipeline_progress(3, "Analyzing variant effects and generating visualizations...")),
        do_flush=True,
    )
    analysis_json = await analyze_effects(scores_json=scores_json, variants_dir=var_dir)

    # Stage 4: Summary
    await flyte.report.replace.aio(
        _wrap_report(_pipeline_progress(4, "Generating comprehensive summary report...")),
        do_flush=True,
    )
    summary_json = await generate_summary(scores_json=scores_json, analysis_json=analysis_json)

    # Final pipeline report
    analysis = json.loads(analysis_json)
    results = json.loads(scores_json)

    final_html = f"""
    <h2>Pipeline Complete</h2>
    <div class="stat-grid">
      <div class="stat"><div class="value">{len(results)}</div><div class="label">Genes Analyzed</div></div>
      <div class="stat"><div class="value">{analysis['total_variants']}</div><div class="label">Variants Scored</div></div>
      <div class="stat"><div class="value">{analysis['accuracy']:.0%}</div><div class="label">Direction Accuracy</div></div>
      <div class="stat"><div class="value">{analysis['precision']:.0%}</div><div class="label">Precision</div></div>
      <div class="stat"><div class="value">{analysis['recall']:.0%}</div><div class="label">Recall</div></div>
    </div>
    <div class="card">
      <b>Model:</b> HuggingFace Carbon |
      <b>Method:</b> Zero-shot log-likelihood ratio scoring |
      <b>Genes:</b> {', '.join(g.split('(')[0].strip() for g in results.keys())}
    </div>
    <div class="note">
      All 4 pipeline stages completed successfully. View individual task reports for detailed
      visualizations including DNA sequence tracks, variant lollipop plots, VEP score charts,
      confusion matrices, and ranked variant tables.
    </div>
    """

    await flyte.report.replace.aio(_wrap_report(final_html), do_flush=True)

    log.info("Pipeline complete.")
    return scores_json, analysis_json

# {{/docs-fragment pipeline}}

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(pipeline)
    print(run.url)
    run.wait()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/genomic_variant_effect/genomic_variant_effect.py*
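The log-likelihood ratio in the method note follows from the chain rule, because Carbon is autoregressive. Here $x^{\mathrm{ref}}$ and $x^{\mathrm{mut}}$ are the reference and mutant windows of length $T$. The second line adds two assumptions: one token per nucleotide, and a single substitution at position $k$. Under those assumptions, every term before $k$ is conditioned on an identical prefix and cancels. A negative $s(v)$ therefore means the model finds the mutant window less probable than the reference.

```latex
\log P(x) = \sum_{t=1}^{T} \log P\left(x_t \mid x_{<t}\right),
\qquad
s(v) = \log P\left(x^{\mathrm{mut}}\right) - \log P\left(x^{\mathrm{ref}}\right)

s(v) = \sum_{t=k}^{T} \left[ \log P\left(x^{\mathrm{mut}}_t \mid x^{\mathrm{mut}}_{<t}\right) - \log P\left(x^{\mathrm{ref}}_t \mid x^{\mathrm{ref}}_{<t}\right) \right]
```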

## Run the workflow

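Clone the [unionai-examples](https://github.com/unionai/unionai-examples) repository. From its root, change into the [example directory](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/genomic_variant_effect) and run the script: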

```
cd v2/tutorials/genomic_variant_effect
uv run --script genomic_variant_effect.py
```

Use a smaller Carbon model for faster iteration:

```
flyte run genomic_variant_effect.py pipeline --model_name HuggingFaceBio/Carbon-500M
```

A negative VEP score means the model assigns the mutant sequence a lower log-likelihood than the reference. In this zero-shot setup, that signal correlates with pathogenicity.
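The sketch below shows one way to turn the scores into fixed-threshold calls and derive the counts the reports display (`true_pos`, `false_pos`, `true_neg`, `false_neg`, precision, recall). The tutorial's `analyze_effects` code is not shown here, so several details are assumptions: the strict `< -0.05` cutoff, the exclusion of `uncertain` variants, and the helper names `classify` and `evaluate`.

```python
"""Hypothetical sketch: threshold zero-shot VEP scores and compare them with known labels."""

THRESHOLD = -0.05  # heuristic cutoff quoted in the pipeline's method note


def classify(score: float, threshold: float = THRESHOLD) -> str:
    """Call a variant pathogenic when its log-likelihood ratio falls below the threshold."""
    return "pathogenic" if score < threshold else "benign"


def evaluate(variants: list[dict], threshold: float = THRESHOLD) -> dict:
    """Build a confusion matrix over variants labeled pathogenic or benign."""
    tp = fp = tn = fn = 0
    for v in variants:
        truth = v["known_effect"]
        if truth not in ("pathogenic", "benign"):
            continue  # assumption: "uncertain" variants have no ground truth, so skip them
        predicted = classify(v["score"], threshold)
        if predicted == "pathogenic":
            if truth == "pathogenic":
                tp += 1
            else:
                fp += 1
        elif truth == "benign":
            tn += 1
        else:
            fn += 1
    total = tp + fp + tn + fn
    return {
        "true_pos": tp,
        "false_pos": fp,
        "true_neg": tn,
        "false_neg": fn,
        "accuracy": (tp + tn) / total if total else 0.0,
        "precision": tp / (tp + fp) if tp + fp else 0.0,
        "recall": tp / (tp + fn) if tp + fn else 0.0,
    }


if __name__ == "__main__":
    # Made-up scores for illustration only; not pipeline output.
    demo = [
        {"score": -0.32, "known_effect": "pathogenic"},
        {"score": -0.08, "known_effect": "pathogenic"},
        {"score": 0.02, "known_effect": "pathogenic"},
        {"score": -0.01, "known_effect": "benign"},
        {"score": -0.20, "known_effect": "benign"},
        {"score": 0.11, "known_effect": "uncertain"},
    ]
    print(evaluate(demo))
```

On the made-up demo list, two pathogenic variants fall below the cutoff, one pathogenic variant is missed, one benign variant is flagged, and the uncertain variant is ignored. That gives 3 correct calls out of 5.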

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/biotech-healthcare/drug-molecule-screening ===

# Drug molecule screening agent

> [!NOTE]
> Code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/drug_molecule_screening).

This tutorial builds an **agentic** virtual drug-screening workflow on Flyte. A medicinal-chemistry agent interprets your therapeutic goal in plain language, derives screening criteria, and composes durable RDKit stage tasks — while the scientific core (property computation, Lipinski filters, Tanimoto similarity, ranking, and HTML reports) stays in trusted, deterministic tools.
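The pure-Python sketch below makes the deterministic core concrete with a rule-of-five filter, Tanimoto similarity, and similarity ranking. It assumes properties and fingerprints were computed beforehand. In RDKit that is typically done with `Descriptors.MolWt`, `Crippen.MolLogP`, `Lipinski.NumHDonors`, `Lipinski.NumHAcceptors`, and `DataStructs.TanimotoSimilarity` on fingerprint objects. The `Props` type, the one-violation tolerance, and the ranking rule are illustrative choices and may differ from the tutorial's code.

```python
"""Hypothetical sketch of a deterministic screening core: Lipinski filter, Tanimoto, ranking."""

from dataclasses import dataclass


@dataclass(frozen=True)
class Props:
    mw: float    # molecular weight (Da)
    logp: float  # octanol-water partition coefficient
    hbd: int     # hydrogen-bond donors
    hba: int     # hydrogen-bond acceptors


def lipinski_violations(p: Props) -> int:
    """Count how many of the four rule-of-five limits a molecule exceeds."""
    return sum([p.mw > 500, p.logp > 5, p.hbd > 5, p.hba > 10])


def passes_lipinski(p: Props, max_violations: int = 1) -> bool:
    """Lipinski's original heuristic tolerates one violation; stricter screens pass 0."""
    return lipinski_violations(p) <= max_violations


def tanimoto(a: set[int], b: set[int]) -> float:
    """Tanimoto similarity of two binary fingerprints given as sets of on-bit indices."""
    union = len(a | b)
    return len(a & b) / union if union else 0.0


def rank_candidates(
    props: dict[str, Props],
    fps: dict[str, set[int]],
    query_fp: set[int],
    max_violations: int = 1,
) -> list[tuple[str, float]]:
    """Keep Lipinski-passing molecules and sort them by similarity to a query, highest first."""
    hits = [
        (name, tanimoto(fps[name], query_fp))
        for name, p in props.items()
        if passes_lipinski(p, max_violations)
    ]
    return sorted(hits, key=lambda item: item[1], reverse=True)


if __name__ == "__main__":
    # Hypothetical molecules and property values, for illustration only.
    props = {
        "mol_a": Props(180.0, 1.3, 1, 4),
        "mol_b": Props(720.0, 6.2, 4, 12),
        "mol_c": Props(450.0, 5.4, 2, 6),
    }
    fps = {"mol_a": {1, 2, 3, 4}, "mol_b": {1, 2}, "mol_c": {3, 4, 5, 6}}
    print(rank_candidates(props, fps, query_fp={1, 2, 3, 5}))
```

Keeping these functions free of LLM calls means the same inputs always produce the same funnel. That is the property the tutorial relies on when it keeps the science in trusted tools.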

The pattern follows how cheminformatics agents like ChemCrow and PharmAgents are built: **the LLM plans and reflects; RDKit computes.**
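One way to enforce "the LLM plans; RDKit computes" is to treat the model's plan as untrusted input. The sketch assumes the agent emits its screening criteria as a JSON object. Deterministic code then parses the plan, drops unknown keys, and clamps values into allowed ranges before any chemistry tool runs. The keys, bounds, defaults, and `validate_plan` itself are hypothetical.

```python
"""Hypothetical sketch: the LLM proposes screening criteria, deterministic code decides what runs."""

import json
import math

# Criteria a plan may set, with the (low, high) range each value is clamped into.
CRITERIA_BOUNDS = {
    "max_mw": (100.0, 900.0),
    "max_logp": (-2.0, 7.0),
    "min_similarity": (0.0, 1.0),
}
# Fallback values used when the plan omits a key or supplies an unusable value.
DEFAULTS = {"max_mw": 500.0, "max_logp": 5.0, "min_similarity": 0.3}


def validate_plan(raw: str) -> dict[str, float]:
    """Parse an LLM-proposed JSON plan, ignore unknown keys, and clamp known ones into bounds."""
    try:
        proposed = json.loads(raw)
    except json.JSONDecodeError:
        proposed = {}
    if not isinstance(proposed, dict):
        proposed = {}

    criteria = dict(DEFAULTS)
    for key, (low, high) in CRITERIA_BOUNDS.items():
        value = proposed.get(key)
        # bool is a subclass of int, so exclude it explicitly; also reject NaN and infinities.
        if isinstance(value, (int, float)) and not isinstance(value, bool) and math.isfinite(value):
            criteria[key] = min(max(float(value), low), high)
    return criteria


if __name__ == "__main__":
    # An over-eager plan: the MW limit is out of range and one key is unknown to the tools.
    print(validate_plan('{"max_mw": 1200, "max_logp": 4.5, "extra_key": true}'))
```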

Flyte provides:

- **Flyte-native agent orchestration** via `flyte.ai.agents.Agent` — see [Flyte-native agents](https://www.union.ai/docs/v2/union/user-guide/build-agent/flyte-agents/page.md)
- **Typed agent tool I/O** — Flyte 2.5.4+ passes `flyte.io.Dir`, `File`, and `DataFrame` between agent tool calls so the LLM can compose multi-step pipelines directly
- **Cached molecule loading** so repeated runs skip re-parsing SMILES
- **Report-enabled stage tasks** that stream property charts, similarity matrices, and candidate spotlights as each step completes
- **Hybrid iteration** — the agent re-runs `screen_candidates` and `generate_report` with adjusted criteria when the funnel is too narrow, reusing cached `molecule_dir` and `properties_json`
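Hybrid iteration can be sketched without the agent. Once properties exist, screening is cheap, so a narrow funnel only requires rerunning the filter with looser limits. In this sketch a fixed relaxation schedule stands in for the agent's judgment. `screen`, `screen_until_enough`, the step sizes, and the round limit are all hypothetical.

```python
"""Hypothetical sketch of hybrid iteration: relax criteria and rerun only the cheap screen."""


def screen(props: dict[str, dict[str, float]], max_mw: float, max_logp: float) -> list[str]:
    """Return the sorted names of molecules within both property limits."""
    return sorted(
        name for name, p in props.items() if p["mw"] <= max_mw and p["logp"] <= max_logp
    )


def screen_until_enough(
    props: dict[str, dict[str, float]],
    min_hits: int = 3,
    max_mw: float = 500.0,
    max_logp: float = 5.0,
    mw_step: float = 50.0,
    logp_step: float = 0.5,
    max_rounds: int = 4,
) -> tuple[list[str], list[tuple[float, float, int]]]:
    """Loosen MW and LogP limits each round until min_hits molecules pass or rounds run out.

    `props` is computed once and reused every round, standing in for the cached
    properties the agent passes back into screening.
    """
    history: list[tuple[float, float, int]] = []
    hits: list[str] = []
    for _ in range(max_rounds + 1):
        hits = screen(props, max_mw, max_logp)
        history.append((max_mw, max_logp, len(hits)))  # record each round's limits and funnel size
        if len(hits) >= min_hits:
            break
        max_mw += mw_step
        max_logp += logp_step
    return hits, history


if __name__ == "__main__":
    # Hypothetical property records, for illustration only.
    demo = {
        "a": {"mw": 300.0, "logp": 2.0},
        "b": {"mw": 520.0, "logp": 4.0},
        "c": {"mw": 480.0, "logp": 5.3},
        "d": {"mw": 800.0, "logp": 6.5},
    }
    hits, history = screen_until_enough(demo)
    for mw, logp, n in history:
        print(f"MW <= {mw:.0f}, LogP <= {logp:.1f}: {n} hits")
```

Returning the `history` alongside the hits lets a report show how far the criteria moved from their defaults. That matters when relaxed thresholds admit molecules outside the classic drug-like zone.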

> [!NOTE] Prerequisites
> Create an Anthropic API key secret (the key name must match the `TaskEnvironment`):
>
> ```
> flyte create secret internal-anthropic-api-key <YOUR_ANTHROPIC_API_KEY>
> ```
>
> See [Secrets](https://www.union.ai/docs/v2/union/user-guide/task-configuration/secrets/page.md) for scoping and file-based secrets.

## Define the task environment

The pipeline runs on CPU with RDKit, LiteLLM, and system libraries for 2D structure rendering.

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.5.4",
#    "litellm",
#    "rdkit",
#    "numpy",
#    "scikit-learn",
#    "pillow",
# ]
# main = "pipeline"
# params = ""
# ///
"""Virtual drug molecule screening — compute properties, apply Lipinski filters, rank candidates."""

import base64
import io
import json
import logging
import math
import os
import tempfile

import flyte
import flyte.io
import flyte.report
from flyte.ai.agents import Agent, tool

MODEL = os.getenv("DRUG_SCREENING_MODEL", "claude-haiku-4-5")

# {{docs-fragment env}}
main_img = flyte.Image.from_uv_script(__file__, name="drug-molecule-screening", pre=True).with_apt_packages(
    "libxrender1", "libxext6", "libexpat1",
)

env = flyte.TaskEnvironment(
    name="drug-molecule-screening",
    image=main_img,
    resources=flyte.Resources(cpu=2, memory="6Gi"),
    secrets=[
        flyte.Secret(key="internal-anthropic-api-key", as_env_var="ANTHROPIC_API_KEY"),
    ],
)
# {{/docs-fragment env}}

logging.basicConfig(level=logging.WARNING, format="%(message)s", force=True)
log = logging.getLogger(__name__)
log.setLevel(logging.INFO)

# ------------------------------------------------------------------
# Default molecule library — real SMILES for well-known drugs
# ------------------------------------------------------------------

DEFAULT_MOLECULES = {
    "Aspirin": "CC(=O)OC1=CC=CC=C1C(=O)O",
    "Ibuprofen": "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O",
    "Caffeine": "CN1C=NC2=C1C(=O)N(C(=O)N2C)C",
    "Penicillin G": "CC1(C(N2C(S1)C(C2=O)NC(=O)CC3=CC=CC=C3)C(=O)O)C",
    "Metformin": "CN(C)C(=N)NC(=N)N",
    "Paracetamol": "CC(=O)NC1=CC=C(C=C1)O",
    "Diazepam": "ClC1=CC2=C(C=C1)N(C(=O)CN=C2C3=CC=CC=C3)C",
    "Omeprazole": "CC1=CN=C(C(=C1OC)C)CS(=O)C2=NC3=C(N2)C=C(C=C3)OC",
    "Atorvastatin": "CC(C)C1=C(C(=C(N1CCC(CC(CC(=O)O)O)O)C2=CC=C(C=C2)F)C3=CC=CC=C3)C(=O)NC4=CC=CC=C4",
    "Methotrexate": "CN(CC1=CN=C2N=C(N=C(N)C2=N1)N)C3=CC=C(C=C3)C(=O)NC(CCC(=O)O)C(=O)O",
    "Doxorubicin": "CC1C(C(CC(O1)OC2CC(CC3=C2C(=C4C(=C3O)C(=O)C5=C(C4=O)C(=CC=C5)OC)O)(C(=O)CO)O)N)O",
    "Tamoxifen": "CCC(=C(C1=CC=CC=C1)C2=CC=C(C=C2)OCCN(C)C)C3=CC=CC=C3",
    "Lopinavir": "CC1=C(C(=CC=C1)C)OCC(=O)NC(CC2=CC=CC=C2)C(CC(CC3=CC=CC=C3)NC(=O)C(C(C)C)N4CCCNC4=O)O",
    "Remdesivir": "CCC(CC)COC(=O)C(C)NP(=O)(OCC1C(C(C(O1)(C#N)C2=CC=C3N2N=CN=C3N)O)O)OC4=CC=CC=C4",
    "Erlotinib": "COCCOC1=CC2=C(C=C1OCCOC)C(=NC=N2)NC3=CC=CC(=C3)C#C",
}

# ------------------------------------------------------------------
# Report styling — pharma blue/cyan theme
# ------------------------------------------------------------------

REPORT_CSS = """
<style>
  .report { font-family: system-ui, -apple-system, sans-serif; max-width: 960px; margin: 0 auto; color: #1a1a2e; }
  .report h2 { color: #0e4f6e; border-bottom: 2px solid #0891b2; padding-bottom: 8px; margin-top: 24px; }
  .report h3 { color: #155e75; margin-top: 20px; }
  .report .card { background: #ecfeff; border: 1px solid #a5f3fc; border-radius: 8px; padding: 16px; margin: 12px 0; }
  .report .stat-grid { display: grid; grid-template-columns: repeat(auto-fit, minmax(160px, 1fr)); gap: 12px; margin: 12px 0; }
  .report .stat { background: #fff; border: 1px solid #cffafe; border-radius: 6px; padding: 12px; text-align: center; }
  .report .stat .value { font-size: 1.5em; font-weight: 700; color: #0e4f6e; }
  .report .stat .label { font-size: 0.85em; color: #6c757d; margin-top: 4px; }
  .report table { border-collapse: collapse; width: 100%; margin: 12px 0; }
  .report th { background: #0e4f6e; color: #fff; padding: 10px 14px; text-align: left; font-weight: 600; }
  .report td { padding: 8px 14px; border-bottom: 1px solid #cffafe; }
  .report tr:nth-child(even) { background: #ecfeff; }
  .report .badge { display: inline-block; padding: 2px 8px; border-radius: 12px; font-size: 0.8em; font-weight: 600; }
  .report .badge-success { background: #d1fae5; color: #065f46; }
  .report .badge-danger { background: #fee2e2; color: #991b1b; }
  .report .badge-info { background: #cffafe; color: #155e75; }
  .report .chart-container { background: #fff; border: 1px solid #cffafe; border-radius: 8px; padding: 16px; margin: 16px 0; }
  .report .note { background: #ecfeff; border-left: 4px solid #0891b2; padding: 10px 14px; border-radius: 4px; margin: 12px 0; font-size: 0.9em; }
  .report .molecule-card { background: #fff; border: 1px solid #cffafe; border-radius: 8px; padding: 16px; margin: 12px 0; }
  .report .molecule-grid { display: grid; grid-template-columns: repeat(auto-fill, minmax(200px, 1fr)); gap: 12px; margin: 16px 0; }
  .report .funnel { text-align: center; margin: 24px 0; }
</style>
"""

def _wrap_report(html: str) -> str:
    """Wrap HTML content with report styling."""
    return f'{REPORT_CSS}<div class="report">{html}</div>'

# ------------------------------------------------------------------
# Image and SVG chart helpers
# ------------------------------------------------------------------

def _mol_to_data_uri(mol, size: tuple[int, int] = (300, 300)) -> str:
    """Convert an RDKit molecule to a PNG base64 data URI."""
    from rdkit.Chem import Draw

    img = Draw.MolToImage(mol, size=size)
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    b64 = base64.b64encode(buf.getvalue()).decode()
    return f"data:image/png;base64,{b64}"

def _make_bar_chart(
    labels: list[str],
    series: dict[str, list[float]],
    title: str = "",
    colors: list[str] | None = None,
    width: int = 700,
    height: int = 340,
    y_max_cap: float | None = None,
    horizontal: bool = False,
    value_fmt: str = ".1f",
) -> str:
    """Generate an SVG grouped bar chart.

    Args:
        labels: Category labels.
        series: Dict mapping series name to list of values.
        title: Chart title.
        colors: Colors for each series.
        width/height: SVG dimensions.
        y_max_cap: Cap the y-axis at this value.
        horizontal: If True, draw horizontal bars.
        value_fmt: Format string for value labels.

    Returns:
        SVG string.
    """
    if not labels:
        return ""

    default_colors = ["#0891b2", "#0e4f6e", "#06d6a0", "#a5f3fc", "#155e75"]
    colors = colors or default_colors

    if horizontal:
        return _make_horizontal_bar_chart(labels, series, title, colors, width, height, value_fmt)

    ml, mr, mt, mb = 60, 20, 40, 60
    cw = width - ml - mr
    ch = height - mt - mb

    all_vals = [v for vals in series.values() for v in vals]
    y_max = max(all_vals) if all_vals else 1
    y_max_plot = y_max * 1.15 or 1
    if y_max_cap is not None:
        y_max_plot = min(y_max_plot, y_max_cap) or y_max_cap

    n_groups = len(labels)
    n_series = len(series)
    group_width = cw / n_groups
    bar_width = group_width * 0.7 / max(n_series, 1)
    gap = group_width * 0.15

    def sy(v):
        return mt + ch - (v / y_max_plot) * ch

    svg = [
        f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {height}" '
        f'style="width:100%;max-width:{width}px;height:auto;">',
        f'<rect width="{width}" height="{height}" fill="#fff" rx="6"/>',
    ]

    # Grid lines
    for i in range(6):
        y_tick = y_max_plot * i / 5
        py = sy(y_tick)
        svg.append(
            f'<line x1="{ml}" y1="{py:.1f}" x2="{ml + cw}" y2="{py:.1f}" '
            f'stroke="#e0f2fe" stroke-width="1"/>'
        )
        svg.append(
            f'<text x="{ml - 8}" y="{py + 4:.1f}" text-anchor="end" '
            f'font-size="11" fill="#6c757d">{y_tick:{value_fmt}}</text>'
        )

    # Axes
    svg.append(
        f'<line x1="{ml}" y1="{mt}" x2="{ml}" y2="{mt + ch}" '
        f'stroke="#94a3b8" stroke-width="1.5"/>'
    )
    svg.append(
        f'<line x1="{ml}" y1="{mt + ch}" x2="{ml + cw}" y2="{mt + ch}" '
        f'stroke="#94a3b8" stroke-width="1.5"/>'
    )

    # Bars
    for gi, label in enumerate(labels):
        gx = ml + gi * group_width + gap
        for si, (name, vals) in enumerate(series.items()):
            color = colors[si % len(colors)]
            bx = gx + si * bar_width
            val = vals[gi]
            by = sy(val)
            bh = mt + ch - by
            svg.append(
                f'<rect x="{bx:.1f}" y="{by:.1f}" width="{bar_width - 1:.1f}" '
                f'height="{bh:.1f}" fill="{color}" rx="2"/>'
            )
            svg.append(
                f'<text x="{bx + bar_width / 2:.1f}" y="{by - 4:.1f}" '
                f'text-anchor="middle" font-size="9" fill="#1a1a2e">'
                f'{val:{value_fmt}}</text>'
            )
        # Truncate long labels
        disp_label = label if len(label) <= 12 else label[:10] + ".."
        svg.append(
            f'<text x="{gx + n_series * bar_width / 2:.1f}" y="{mt + ch + 16}" '
            f'text-anchor="middle" font-size="10" fill="#6c757d" '
            f'transform="rotate(-35, {gx + n_series * bar_width / 2:.1f}, {mt + ch + 16})">'
            f'{disp_label}</text>'
        )

    # Title
    if title:
        svg.append(
            f'<text x="{width / 2}" y="22" text-anchor="middle" '
            f'font-size="14" font-weight="600" fill="#0e4f6e">{title}</text>'
        )

    # Legend
    if n_series > 1:
        lx = ml + cw - len(series) * 100
        for si, name in enumerate(series):
            color = colors[si % len(colors)]
            svg.append(
                f'<rect x="{lx + si * 100}" y="{mt + ch + 40}" width="12" '
                f'height="12" rx="2" fill="{color}"/>'
            )
            svg.append(
                f'<text x="{lx + si * 100 + 16}" y="{mt + ch + 51}" font-size="11" '
                f'fill="#1a1a2e">{name}</text>'
            )

    svg.append("</svg>")
    return "\n".join(svg)

def _make_horizontal_bar_chart(
    labels: list[str],
    series: dict[str, list[float]],
    title: str = "",
    colors: list[str] | None = None,
    width: int = 700,
    height: int = 400,
    value_fmt: str = ".1f",
) -> str:
    """Generate an SVG horizontal bar chart from the first series; bars are drawn in input order, so sort the inputs first if a ranked chart is wanted."""
    default_colors = ["#0891b2", "#0e4f6e", "#06d6a0"]
    colors = colors or default_colors

    n = len(labels)
    row_height = max(22, min(35, (height - 80) // max(n, 1)))
    actual_height = max(height, 80 + n * row_height)
    ml, mr, mt, mb = 120, 60, 40, 20
    cw = width - ml - mr
    ch = actual_height - mt - mb

    # Use first series
    first_key = list(series.keys())[0]
    vals = series[first_key]
    x_max = max(vals) * 1.15 if vals else 1

    svg = [
        f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {actual_height}" '
        f'style="width:100%;max-width:{width}px;height:auto;">',
        f'<rect width="{width}" height="{actual_height}" fill="#fff" rx="6"/>',
    ]

    if title:
        svg.append(
            f'<text x="{width / 2}" y="22" text-anchor="middle" '
            f'font-size="14" font-weight="600" fill="#0e4f6e">{title}</text>'
        )

    bar_h = row_height * 0.65
    for i, (label, val) in enumerate(zip(labels, vals)):
        y = mt + i * row_height
        bw = (val / x_max) * cw if x_max else 0
        color = colors[i % len(colors)]
        # Label
        disp = label if len(label) <= 14 else label[:12] + ".."
        svg.append(
            f'<text x="{ml - 8}" y="{y + bar_h / 2 + 4:.1f}" text-anchor="end" '
            f'font-size="11" fill="#1a1a2e">{disp}</text>'
        )
        # Bar
        svg.append(
            f'<rect x="{ml}" y="{y:.1f}" width="{bw:.1f}" height="{bar_h:.1f}" '
            f'fill="{color}" rx="3"/>'
        )
        # Value
        svg.append(
            f'<text x="{ml + bw + 6:.1f}" y="{y + bar_h / 2 + 4:.1f}" '
            f'font-size="11" fill="#0e4f6e" font-weight="600">{val:{value_fmt}}</text>'
        )

    svg.append("</svg>")
    return "\n".join(svg)

def _make_heatmap(
    matrix: list[list[float]],
    row_labels: list[str],
    col_labels: list[str],
    title: str = "",
    color_scale: str = "cyan",
    width: int = 700,
    height: int = 500,
    value_fmt: str = ".2f",
) -> str:
    """Generate an SVG heatmap.

    Args:
        matrix: 2D list of values (rows x cols).
        row_labels: Labels for rows.
        col_labels: Labels for columns.
        title: Chart title.
        color_scale: Color scheme ("cyan", "red", "green").
        width/height: SVG dimensions.
        value_fmt: Format string for cell values.

    Returns:
        SVG string.
    """
    if not matrix or not matrix[0]:
        return ""

    n_rows = len(matrix)
    n_cols = len(matrix[0])

    ml, mr, mt, mb = 110, 20, 70, 20
    cw = width - ml - mr
    ch = height - mt - mb
    cell_w = cw / n_cols
    cell_h = ch / n_rows

    # Flatten to find range
    flat = [v for row in matrix for v in row]
    v_min = min(flat)
    v_max = max(flat)
    v_range = v_max - v_min or 1

    def color_for(v):
        t = (v - v_min) / v_range
        if color_scale == "cyan":
            # White to deep teal
            r = int(255 - t * (255 - 14))
            g = int(255 - t * (255 - 79))
            b = int(255 - t * (255 - 110))
        elif color_scale == "red":
            r = int(255 - t * 50)
            g = int(255 - t * 200)
            b = int(255 - t * 200)
        else:  # green
            r = int(255 - t * 200)
            g = int(255 - t * 50)
            b = int(255 - t * 200)
        return f"rgb({r},{g},{b})"

    svg = [
        f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {height}" '
        f'style="width:100%;max-width:{width}px;height:auto;">',
        f'<rect width="{width}" height="{height}" fill="#fff" rx="6"/>',
    ]

    if title:
        svg.append(
            f'<text x="{width / 2}" y="22" text-anchor="middle" '
            f'font-size="14" font-weight="600" fill="#0e4f6e">{title}</text>'
        )

    # Column labels (rotated)
    for ci, label in enumerate(col_labels):
        x = ml + ci * cell_w + cell_w / 2
        disp = label if len(label) <= 12 else label[:10] + ".."
        svg.append(
            f'<text x="{x:.1f}" y="{mt - 8}" text-anchor="end" font-size="10" '
            f'fill="#1a1a2e" transform="rotate(-45, {x:.1f}, {mt - 8})">{disp}</text>'
        )

    # Row labels + cells
    for ri, (row_label, row_vals) in enumerate(zip(row_labels, matrix)):
        y = mt + ri * cell_h
        disp = row_label if len(row_label) <= 14 else row_label[:12] + ".."
        svg.append(
            f'<text x="{ml - 8}" y="{y + cell_h / 2 + 4:.1f}" text-anchor="end" '
            f'font-size="10" fill="#1a1a2e">{disp}</text>'
        )
        for ci, val in enumerate(row_vals):
            x = ml + ci * cell_w
            fill = color_for(val)
            svg.append(
                f'<rect x="{x:.1f}" y="{y:.1f}" width="{cell_w:.1f}" '
                f'height="{cell_h:.1f}" fill="{fill}" stroke="#fff" stroke-width="1"/>'
            )
            # Text color: dark on light, light on dark
            t = (val - v_min) / v_range
            txt_color = "#fff" if t > 0.55 else "#1a1a2e"
            # Only show text if cells are large enough
            if cell_w > 30 and cell_h > 18:
                svg.append(
                    f'<text x="{x + cell_w / 2:.1f}" y="{y + cell_h / 2 + 4:.1f}" '
                    f'text-anchor="middle" font-size="9" fill="{txt_color}">'
                    f'{val:{value_fmt}}</text>'
                )

    svg.append("</svg>")
    return "\n".join(svg)

def _make_scatter_plot(
    points: list[dict],
    x_label: str = "MW",
    y_label: str = "LogP",
    title: str = "",
    reference_lines: list[dict] | None = None,
    width: int = 700,
    height: int = 400,
) -> str:
    """Generate an SVG scatter plot.

    Args:
        points: List of dicts with "x", "y", "label" keys.
        x_label/y_label: Axis labels.
        title: Chart title.
        reference_lines: List of dicts with "axis" ("x"/"y"), "value", "label".
        width/height: SVG dimensions.

    Returns:
        SVG string.
    """
    if not points:
        return ""

    ml, mr, mt, mb = 60, 30, 40, 50
    cw = width - ml - mr
    ch = height - mt - mb

    x_vals = [p["x"] for p in points]
    y_vals = [p["y"] for p in points]
    x_min, x_max = min(x_vals) * 0.9, max(x_vals) * 1.1
    y_min, y_max = min(y_vals) - 1, max(y_vals) + 1

    # Extend ranges to include reference lines
    if reference_lines:
        for rl in reference_lines:
            if rl["axis"] == "x":
                x_max = max(x_max, rl["value"] * 1.1)
            else:
                y_max = max(y_max, rl["value"] * 1.1)

    x_range = x_max - x_min or 1
    y_range = y_max - y_min or 1

    def sx(v):
        return ml + (v - x_min) / x_range * cw

    def sy(v):
        return mt + ch - (v - y_min) / y_range * ch

    svg = [
        f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {height}" '
        f'style="width:100%;max-width:{width}px;height:auto;">',
        f'<rect width="{width}" height="{height}" fill="#fff" rx="6"/>',
    ]

    # Grid
    for i in range(6):
        y_tick = y_min + y_range * i / 5
        py = sy(y_tick)
        svg.append(
            f'<line x1="{ml}" y1="{py:.1f}" x2="{ml + cw}" y2="{py:.1f}" '
            f'stroke="#e0f2fe" stroke-width="1"/>'
        )
        svg.append(
            f'<text x="{ml - 8}" y="{py + 4:.1f}" text-anchor="end" '
            f'font-size="11" fill="#6c757d">{y_tick:.1f}</text>'
        )

    for i in range(6):
        x_tick = x_min + x_range * i / 5
        px = sx(x_tick)
        svg.append(
            f'<text x="{px:.1f}" y="{mt + ch + 20}" text-anchor="middle" '
            f'font-size="11" fill="#6c757d">{x_tick:.0f}</text>'
        )

    # Axes
    svg.append(
        f'<line x1="{ml}" y1="{mt}" x2="{ml}" y2="{mt + ch}" '
        f'stroke="#94a3b8" stroke-width="1.5"/>'
    )
    svg.append(
        f'<line x1="{ml}" y1="{mt + ch}" x2="{ml + cw}" y2="{mt + ch}" '
        f'stroke="#94a3b8" stroke-width="1.5"/>'
    )

    # Reference lines (Lipinski boundaries)
    if reference_lines:
        for rl in reference_lines:
            if rl["axis"] == "x":
                px = sx(rl["value"])
                svg.append(
                    f'<line x1="{px:.1f}" y1="{mt}" x2="{px:.1f}" y2="{mt + ch}" '
                    f'stroke="#ef4444" stroke-width="1.5" stroke-dasharray="6,4"/>'
                )
                svg.append(
                    f'<text x="{px + 4:.1f}" y="{mt + 14}" font-size="10" '
                    f'fill="#ef4444" font-weight="600">{rl["label"]}</text>'
                )
            else:
                py = sy(rl["value"])
                svg.append(
                    f'<line x1="{ml}" y1="{py:.1f}" x2="{ml + cw}" y2="{py:.1f}" '
                    f'stroke="#ef4444" stroke-width="1.5" stroke-dasharray="6,4"/>'
                )
                svg.append(
                    f'<text x="{ml + cw - 4:.1f}" y="{py - 6:.1f}" text-anchor="end" '
                    f'font-size="10" fill="#ef4444" font-weight="600">{rl["label"]}</text>'
                )

    # Drug-like zone shading (MW<=500 and LogP<=5 quadrant)
    if reference_lines:
        mw_line = next((rl for rl in reference_lines if rl["axis"] == "x"), None)
        logp_line = next((rl for rl in reference_lines if rl["axis"] == "y"), None)
        if mw_line and logp_line:
            zx1 = sx(x_min)
            zx2 = sx(min(mw_line["value"], x_max))
            zy1 = sy(min(logp_line["value"], y_max))
            zy2 = sy(y_min)
            svg.append(
                f'<rect x="{zx1:.1f}" y="{zy1:.1f}" '
                f'width="{zx2 - zx1:.1f}" height="{zy2 - zy1:.1f}" '
                f'fill="#0891b2" opacity="0.06" rx="4"/>'
            )
            svg.append(
                f'<text x="{zx1 + 8:.1f}" y="{zy2 - 8:.1f}" font-size="11" '
                f'fill="#0891b2" font-weight="600" opacity="0.6">Drug-like Zone</text>'
            )

    # Points
    point_colors = ["#0891b2", "#0e4f6e", "#06d6a0", "#155e75", "#0284c7",
                    "#059669", "#0d9488", "#0369a1", "#047857", "#115e59",
                    "#0c4a6e", "#064e3b", "#1e3a5f", "#134e4a", "#075985"]
    for i, pt in enumerate(points):
        px, py = sx(pt["x"]), sy(pt["y"])
        color = point_colors[i % len(point_colors)]
        svg.append(
            f'<circle cx="{px:.1f}" cy="{py:.1f}" r="5" fill="{color}" '
            f'stroke="#fff" stroke-width="1.5" opacity="0.85"/>'
        )
        # Alternate label offsets above/below the point to reduce overlap
        offset_x = 8
        offset_y = -8 if i % 2 == 0 else 14
        label = pt["label"] if len(pt["label"]) <= 12 else pt["label"][:10] + ".."
        svg.append(
            f'<text x="{px + offset_x:.1f}" y="{py + offset_y:.1f}" '
            f'font-size="9" fill="#1a1a2e">{label}</text>'
        )

    # Title
    if title:
        svg.append(
            f'<text x="{width / 2}" y="22" text-anchor="middle" '
            f'font-size="14" font-weight="600" fill="#0e4f6e">{title}</text>'
        )

    # Axis labels
    if x_label:
        svg.append(
            f'<text x="{ml + cw / 2}" y="{height - 6}" text-anchor="middle" '
            f'font-size="12" fill="#6c757d">{x_label}</text>'
        )
    if y_label:
        svg.append(
            f'<text x="14" y="{mt + ch / 2}" text-anchor="middle" '
            f'font-size="12" fill="#6c757d" '
            f'transform="rotate(-90, 14, {mt + ch / 2})">{y_label}</text>'
        )

    svg.append("</svg>")
    return "\n".join(svg)

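# Illustrative helper (hypothetical, not called by the tutorial): the chart
# builders above interpolate labels and titles into SVG text nodes unescaped,
# so a molecule name containing "&" or "<" would produce malformed SVG. This is
# a minimal sketch of escaping such text with the standard-library html.escape.
# It assumes you would wrap values such as {label} or {title} with it before
# interpolation.
def _svg_text(value: object) -> str:
    """Return `value` as a string safe to embed in SVG/XML text or attributes."""
    import html

    # quote=True also escapes " and ', which matters inside attribute values.
    return html.escape(str(value), quote=True)
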
def _make_funnel(
    stages: list[dict],
    title: str = "",
    width: int = 600,
    height: int = 400,
) -> str:
    """Generate an SVG funnel visualization.

    Args:
        stages: List of dicts with "label", "count", "total" keys.
        title: Chart title.
        width/height: SVG dimensions.

    Returns:
        SVG string.
    """
    if not stages:
        return ""

    n = len(stages)
    mt = 50
    mb = 20
    available_h = height - mt - mb
    stage_h = available_h / n
    cx = width / 2

    # Color gradient from light cyan to deep teal
    colors = []
    for i in range(n):
        t = i / max(n - 1, 1)
        r = int(207 - t * (207 - 14))
        g = int(250 - t * (250 - 79))
        b = int(254 - t * (254 - 110))
        colors.append(f"rgb({r},{g},{b})")

    svg = [
        f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {height}" '
        f'style="width:100%;max-width:{width}px;height:auto;">',
        f'<rect width="{width}" height="{height}" fill="#fff" rx="6"/>',
    ]

    if title:
        svg.append(
            f'<text x="{cx}" y="28" text-anchor="middle" '
            f'font-size="16" font-weight="700" fill="#0e4f6e">{title}</text>'
        )

    # Guard against division by zero when the first stage is empty.
    max_count = stages[0]["count"] or 1
    max_width = width * 0.75

    for i, stage in enumerate(stages):
        y_top = mt + i * stage_h
        y_bot = y_top + stage_h

        # Width proportional to count
        w_top = max_width * (stage["count"] / max_count) if i == 0 else prev_w_bot
        if i < n - 1:
            w_bot = max_width * (stages[i + 1]["count"] / max_count)
        else:
            w_bot = max_width * (stage["count"] / max_count) * 0.7

        prev_w_bot = w_bot

        # Trapezoid
        x1_top = cx - w_top / 2
        x2_top = cx + w_top / 2
        x1_bot = cx - w_bot / 2
        x2_bot = cx + w_bot / 2

        svg.append(
            f'<polygon points="{x1_top:.1f},{y_top:.1f} {x2_top:.1f},{y_top:.1f} '
            f'{x2_bot:.1f},{y_bot:.1f} {x1_bot:.1f},{y_bot:.1f}" '
            f'fill="{colors[i]}" stroke="#fff" stroke-width="2"/>'
        )

        # Text: dark on light, white on dark
        t = i / max(n - 1, 1)
        txt_color = "#0e4f6e" if t < 0.5 else "#fff"
        y_mid = (y_top + y_bot) / 2

        svg.append(
            f'<text x="{cx}" y="{y_mid - 4:.1f}" text-anchor="middle" '
            f'font-size="13" font-weight="600" fill="{txt_color}">{stage["label"]}</text>'
        )
        svg.append(
            f'<text x="{cx}" y="{y_mid + 14:.1f}" text-anchor="middle" '
            f'font-size="12" fill="{txt_color}" opacity="0.85">'
            f'{stage["count"]} / {stage["total"]}</text>'
        )

    svg.append("</svg>")
    return "\n".join(svg)

# ------------------------------------------------------------------
# Task 1: Load and validate molecules
# ------------------------------------------------------------------

@tool
@env.task(cache="auto")
async def load_molecules(
    molecules_json: str = "",
) -> flyte.io.Dir:
    """Parse SMILES strings, validate with RDKit, generate 2D depictions.

    Args:
        molecules_json: JSON string mapping molecule names to SMILES.
            Defaults to a curated library of ~15 well-known drugs.

    Returns:
        flyte.io.Dir containing molecule data (JSON + PNG depictions).
        Pass this directory to compute_properties and generate_report.
    """
    from rdkit import Chem
    from rdkit.Chem import Draw

    if molecules_json.strip():
        molecules = json.loads(molecules_json)
    else:
        molecules = DEFAULT_MOLECULES

    out_dir = tempfile.mkdtemp(prefix="mol_library_")
    results = []
    valid_count = 0
    invalid_count = 0

    log.info(f"Parsing {len(molecules)} molecules...")

    for name, smiles in molecules.items():
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            log.warning(f"  [INVALID] {name}: {smiles}")
            invalid_count += 1
            continue

        valid_count += 1

        # Generate 2D depiction as PNG
        img = Draw.MolToImage(mol, size=(300, 300))
        img_path = os.path.join(out_dir, f"{name.replace(' ', '_')}.png")
        img.save(img_path)

        results.append({
            "name": name,
            "smiles": smiles,
            "valid": True,
            "image_file": os.path.basename(img_path),
        })

    # Save molecule manifest
    manifest = {
        "total": len(molecules),
        "valid": valid_count,
        "invalid": invalid_count,
        "molecules": results,
    }
    manifest_path = os.path.join(out_dir, "manifest.json")
    with open(manifest_path, "w") as f:
        json.dump(manifest, f, indent=2)

    log.info(f"Loaded {valid_count} valid molecules ({invalid_count} invalid)")

    return await flyte.io.Dir.from_local(out_dir)

# ------------------------------------------------------------------
# Task 2: Compute physicochemical properties
# ------------------------------------------------------------------

@tool
@env.task(report=True)
async def compute_properties(
    molecule_dir: flyte.io.Dir,
) -> str:
    """Compute drug-likeness properties for all molecules.

    Computes MW, LogP, HBD, HBA, TPSA, rotatable bonds, formal charge,
    ring count, QED, and Lipinski Rule of Five compliance.

    Args:
        molecule_dir: Directory from load_molecules.

    Returns:
        JSON string with all computed properties. Pass to screen_candidates
        and generate_report.
    """
    from rdkit import Chem
    from rdkit.Chem import Descriptors, Lipinski
    from rdkit.Chem.QED import qed

    # --- Loading report ---
    await flyte.report.replace.aio(
        _wrap_report("<h2>Computing Molecular Properties...</h2>"
                      "<p>Analyzing physicochemical descriptors for all molecules.</p>"),
        do_flush=True,
    )

    mol_dir = await molecule_dir.download()
    with open(os.path.join(mol_dir, "manifest.json")) as f:
        manifest = json.load(f)

    molecules_data = []
    lipinski_pass = 0

    for mol_info in manifest["molecules"]:
        mol = Chem.MolFromSmiles(mol_info["smiles"])
        if mol is None:
            continue

        mw = Descriptors.MolWt(mol)
        logp = Descriptors.MolLogP(mol)
        hbd = Lipinski.NumHDonors(mol)
        hba = Lipinski.NumHAcceptors(mol)
        tpsa = Descriptors.TPSA(mol)
        rotatable = Lipinski.NumRotatableBonds(mol)
        formal_charge = Chem.GetFormalCharge(mol)
        num_rings = Lipinski.RingCount(mol)
        qed_score = qed(mol)

        # Lipinski Rule of Five
        lipinski = {
            "mw_ok": mw <= 500,
            "logp_ok": logp <= 5,
            "hbd_ok": hbd <= 5,
            "hba_ok": hba <= 10,
        }
        lipinski_all = all(lipinski.values())
        if lipinski_all:
            lipinski_pass += 1

        # Read image for data URI
        img_path = os.path.join(mol_dir, mol_info["image_file"])
        data_uri = ""
        if os.path.exists(img_path):
            with open(img_path, "rb") as img_f:
                b64 = base64.b64encode(img_f.read()).decode()
                data_uri = f"data:image/png;base64,{b64}"

        molecules_data.append({
            "name": mol_info["name"],
            "smiles": mol_info["smiles"],
            "mw": round(mw, 2),
            "logp": round(logp, 2),
            "hbd": hbd,
            "hba": hba,
            "tpsa": round(tpsa, 2),
            "rotatable_bonds": rotatable,
            "formal_charge": formal_charge,
            "num_rings": num_rings,
            "qed": round(qed_score, 4),
            "lipinski": lipinski,
            "lipinski_pass": lipinski_all,
            "image_data_uri": data_uri,
        })

    total = len(molecules_data)
    avg_mw = sum(m["mw"] for m in molecules_data) / total if total else 0
    avg_logp = sum(m["logp"] for m in molecules_data) / total if total else 0
    lipinski_rate = lipinski_pass / total * 100 if total else 0

    # ---- Build report ----
    html_parts = []

    # Header
    html_parts.append("<h2>Molecular Properties Analysis</h2>")

    # Stat grid
    html_parts.append('<div class="stat-grid">')
    for val, label in [
        (str(total), "Total Molecules"),
        (f"{lipinski_rate:.0f}%", "Lipinski Pass Rate"),
        (f"{avg_mw:.1f}", "Avg. MW (Da)"),
        (f"{avg_logp:.2f}", "Avg. LogP"),
    ]:
        html_parts.append(
            f'<div class="stat"><div class="value">{val}</div>'
            f'<div class="label">{label}</div></div>'
        )
    html_parts.append("</div>")

    # Molecule gallery
    html_parts.append("<h3>Molecule Library</h3>")
    html_parts.append('<div class="molecule-grid">')
    for m in molecules_data:
        if m["image_data_uri"]:
            badge_class = "badge-success" if m["lipinski_pass"] else "badge-danger"
            badge_text = "Lipinski Pass" if m["lipinski_pass"] else "Lipinski Fail"
            html_parts.append(
                f'<div class="molecule-card" style="text-align:center;">'
                f'<img src="{m["image_data_uri"]}" style="width:160px;height:160px;object-fit:contain;"/>'
                f'<div style="font-weight:600;margin-top:6px;color:#0e4f6e;">{m["name"]}</div>'
                f'<div style="font-size:0.8em;color:#6c757d;">MW: {m["mw"]:.1f} | LogP: {m["logp"]:.2f}</div>'
                f'<div><span class="badge {badge_class}">{badge_text}</span></div>'
                f'</div>'
            )
    html_parts.append("</div>")

    # MW bar chart (horizontal, sorted)
    sorted_by_mw = sorted(molecules_data, key=lambda m: m["mw"], reverse=True)
    mw_labels = [m["name"] for m in sorted_by_mw]
    mw_vals = [m["mw"] for m in sorted_by_mw]
    mw_chart = _make_bar_chart(
        mw_labels, {"MW (Da)": mw_vals},
        title="Molecular Weight Distribution",
        horizontal=True,
        width=700, height=max(300, len(mw_labels) * 30 + 80),
        value_fmt=".1f",
    )
    html_parts.append("<h3>Molecular Weight</h3>")
    html_parts.append(f'<div class="chart-container">{mw_chart}</div>')

    # LogP vs MW scatter plot
    scatter_points = [
        {"x": m["mw"], "y": m["logp"], "label": m["name"]}
        for m in molecules_data
    ]
    scatter_chart = _make_scatter_plot(
        scatter_points,
        x_label="Molecular Weight (Da)",
        y_label="LogP",
        title="LogP vs. Molecular Weight (Lipinski Boundaries)",
        reference_lines=[
            {"axis": "x", "value": 500, "label": "MW = 500"},
            {"axis": "y", "value": 5, "label": "LogP = 5"},
        ],
        width=700,
        height=420,
    )
    html_parts.append("<h3>Lipinski Space</h3>")
    html_parts.append(f'<div class="chart-container">{scatter_chart}</div>')

    # Property heatmap (molecules x properties)
    prop_names = ["MW", "LogP", "HBD", "HBA", "TPSA", "Rot. Bonds"]
    # Normalize each property to 0-1 for heatmap
    raw_matrix = []
    for m in molecules_data:
        raw_matrix.append([m["mw"], m["logp"], m["hbd"], m["hba"], m["tpsa"], m["rotatable_bonds"]])

    # Normalize per column
    n_props = len(prop_names)
    # default=0 avoids a ValueError when no molecule parsed successfully.
    col_min = [min((row[c] for row in raw_matrix), default=0) for c in range(n_props)]
    col_max = [max((row[c] for row in raw_matrix), default=0) for c in range(n_props)]
    norm_matrix = []
    for row in raw_matrix:
        norm_row = []
        for c in range(n_props):
            rng = col_max[c] - col_min[c]
            norm_row.append((row[c] - col_min[c]) / rng if rng else 0.5)
        norm_matrix.append(norm_row)

    heatmap_labels = [m["name"] for m in molecules_data]
    heatmap = _make_heatmap(
        norm_matrix, heatmap_labels, prop_names,
        title="Normalized Property Heatmap",
        color_scale="cyan",
        width=700,
        height=max(400, len(heatmap_labels) * 28 + 100),
    )
    html_parts.append("<h3>Property Heatmap</h3>")
    html_parts.append(f'<div class="chart-container">{heatmap}</div>')

    # Lipinski compliance table
    html_parts.append("<h3>Lipinski Rule of Five Compliance</h3>")
    html_parts.append("<table><tr><th>Molecule</th><th>MW &le; 500</th>"
                      "<th>LogP &le; 5</th><th>HBD &le; 5</th>"
                      "<th>HBA &le; 10</th><th>Overall</th></tr>")
    def _badge(ok):
        if ok:
            return '<span class="badge badge-success">Pass</span>'
        return '<span class="badge badge-danger">Fail</span>'

    for m in molecules_data:
        lip = m["lipinski"]
        overall_badge = _badge(m["lipinski_pass"])
        html_parts.append(
            f'<tr><td><strong>{m["name"]}</strong></td>'
            f'<td>{_badge(lip["mw_ok"])}</td>'
            f'<td>{_badge(lip["logp_ok"])}</td>'
            f'<td>{_badge(lip["hbd_ok"])}</td>'
            f'<td>{_badge(lip["hba_ok"])}</td>'
            f'<td>{overall_badge}</td></tr>'
        )
    html_parts.append("</table>")

    # QED bar chart
    sorted_by_qed = sorted(molecules_data, key=lambda m: m["qed"], reverse=True)
    qed_labels = [m["name"] for m in sorted_by_qed]
    qed_vals = [m["qed"] for m in sorted_by_qed]
    qed_chart = _make_bar_chart(
        qed_labels, {"QED Score": qed_vals},
        title="Drug-likeness (QED Score)",
        horizontal=True,
        width=700, height=max(300, len(qed_labels) * 30 + 80),
        value_fmt=".3f",
        colors=["#06d6a0"],
    )
    html_parts.append("<h3>Drug-likeness (QED)</h3>")
    html_parts.append(f'<div class="chart-container">{qed_chart}</div>')

    # Flush full report
    await flyte.report.replace.aio(
        _wrap_report("\n".join(html_parts)),
        do_flush=True,
    )

    # Return properties as JSON (strip image data URIs to reduce size)
    output = {
        "total": total,
        "lipinski_pass_count": lipinski_pass,
        "lipinski_pass_rate": round(lipinski_rate, 2),
        "avg_mw": round(avg_mw, 2),
        "avg_logp": round(avg_logp, 2),
        "molecules": [
            {k: v for k, v in m.items() if k != "image_data_uri"}
            for m in molecules_data
        ],
    }
    return json.dumps(output)

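# Illustrative helper (hypothetical, not called by the tutorial): the heatmap
# step in compute_properties min-max scales each property column to [0, 1] so
# that MW (hundreds of Da) and HBD (single digits) share one color scale.
# This is the same per-column normalization factored out as a standalone
# sketch that also tolerates an empty matrix:
#     x_norm = (x - col_min) / (col_max - col_min), or 0.5 for a constant column
def _minmax_normalize_columns(matrix: list[list[float]]) -> list[list[float]]:
    """Scale each column of `matrix` to [0, 1]; constant columns map to 0.5."""
    if not matrix:
        return []
    n_cols = len(matrix[0])
    col_min = [min(row[c] for row in matrix) for c in range(n_cols)]
    col_max = [max(row[c] for row in matrix) for c in range(n_cols)]
    normalized = []
    for row in matrix:
        normalized.append([
            (row[c] - col_min[c]) / (col_max[c] - col_min[c])
            if col_max[c] != col_min[c] else 0.5
            for c in range(n_cols)
        ])
    return normalized
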
# ------------------------------------------------------------------
# Task 3: Screen candidates against target profile
# ------------------------------------------------------------------

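# Illustrative helper (hypothetical, not called by the tutorial): the
# per-criterion scoring rule inside screen_candidates, isolated for clarity.
# A value inside the inclusive range [lo, hi] earns 1 point plus a bonus of up
# to 0.5 that shrinks linearly with distance from the range midpoint. A value
# outside the range earns 0.
def _range_score(val: float, lo: float, hi: float) -> float:
    """Score `val` against the inclusive range [lo, hi]; result is in [0, 1.5]."""
    if not lo <= val <= hi:
        return 0.0
    mid = (lo + hi) / 2
    half_width = (hi - lo) / 2
    # Normalized distance from the midpoint: 0 at the center, 1 at either bound.
    dist = abs(val - mid) / half_width if half_width else 0.0
    return 1.0 + max(0.0, 0.5 * (1 - dist))
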
@tool
@env.task(report=True)
async def screen_candidates(
    properties_json: str,
    target_profile: str = "",
) -> str:
    """Screen molecules against a target drug profile and rank candidates.

    Scores each molecule on how well it matches the target profile, computes
    pairwise Tanimoto similarity, and produces a ranked list.

    Args:
        properties_json: JSON from compute_properties.
        target_profile: JSON string with desired property ranges
            (e.g. {"mw": [150, 500], "logp": [-0.5, 5.0]}).

    Returns:
        JSON string with ranked_molecules, similarity_matrix, similarity_labels,
        funnel, and target_profile. Pass the full return value verbatim to
        generate_report along with molecule_dir and properties_json.
    """
    from rdkit import Chem, DataStructs
    from rdkit.Chem import AllChem

    await flyte.report.replace.aio(
        _wrap_report("<h2>Screening Candidates...</h2>"
                      "<p>Evaluating molecules against the target drug profile.</p>"),
        do_flush=True,
    )

    props = json.loads(properties_json)
    molecules = props["molecules"]

    # Default target profile
    if target_profile.strip():
        profile = json.loads(target_profile)
    else:
        profile = {
            "mw": [150, 500],
            "logp": [-0.5, 5.0],
            "hbd": [0, 5],
            "hba": [0, 10],
            "tpsa": [20, 140],
        }

    # --- Screening ---
    funnel_total = len(molecules)
    pass_mw = 0
    pass_logp = 0
    pass_lipinski = 0
    final_candidates = 0

    scored = []
    for m in molecules:
        score = 0
        max_score = 0
        criteria = {}

        # Check each profile criterion
        checks = [
            ("mw", m["mw"]),
            ("logp", m["logp"]),
            ("hbd", m["hbd"]),
            ("hba", m["hba"]),
            ("tpsa", m["tpsa"]),
        ]

        for key, val in checks:
            if key in profile:
                lo, hi = profile[key]
                max_score += 1.5  # 1 for being in range + up to 0.5 midpoint bonus
                in_range = lo <= val <= hi
                criteria[key] = in_range
                if in_range:
                    score += 1
                    # Bonus: closer to midpoint = higher score
                    mid = (lo + hi) / 2
                    rng = (hi - lo) / 2
                    dist = abs(val - mid) / rng if rng else 0
                    score += max(0, 0.5 * (1 - dist))

        # QED bonus
        score += m["qed"] * 2
        max_score += 2

        # Lipinski bonus
        if m["lipinski_pass"]:
            score += 1
        max_score += 1

        normalized_score = score / max_score if max_score else 0

        # Funnel tracking: cascading filter (each stage requires passing the previous one)
        mw_ok = criteria.get("mw", True)
        logp_ok = criteria.get("logp", True)
        if mw_ok:
            pass_mw += 1
            if logp_ok:
                pass_logp += 1
                if m["lipinski_pass"]:
                    pass_lipinski += 1
                    if all(criteria.values()):
                        final_candidates += 1

        scored.append({
            **m,
            "screening_score": round(normalized_score, 4),
            "criteria_met": criteria,
            "all_criteria_met": all(criteria.values()),
        })

    # Sort by score descending
    scored.sort(key=lambda m: m["screening_score"], reverse=True)

    # --- Tanimoto similarity matrix ---
    fps = []
    valid_names = []
    for m in scored:
        mol = Chem.MolFromSmiles(m["smiles"])
        if mol:
            fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048)
            fps.append(fp)
            valid_names.append(m["name"])

    similarity_matrix = []
    for i in range(len(fps)):
        row = []
        for j in range(len(fps)):
            sim = DataStructs.TanimotoSimilarity(fps[i], fps[j])
            row.append(round(sim, 3))
        similarity_matrix.append(row)

    # ---- Build report ----
    html_parts = []
    html_parts.append("<h2>Candidate Screening Results</h2>")

    # Stat grid
    html_parts.append('<div class="stat-grid">')
    for val, label in [
        (str(funnel_total), "Total Screened"),
        (str(pass_lipinski), "Lipinski Passes"),
        (str(final_candidates), "All Criteria Met"),
        (f"{scored[0]['screening_score']:.3f}" if scored else "N/A", "Top Score"),
    ]:
        html_parts.append(
            f'<div class="stat"><div class="value">{val}</div>'
            f'<div class="label">{label}</div></div>'
        )
    html_parts.append("</div>")

    # Screening funnel
    funnel_stages = [
        {"label": "Total Molecules", "count": funnel_total, "total": funnel_total},
        {"label": "Pass MW Filter", "count": pass_mw, "total": funnel_total},
        {"label": "Pass LogP Filter", "count": pass_logp, "total": funnel_total},
        {"label": "Lipinski Compliant", "count": pass_lipinski, "total": funnel_total},
        {"label": "All Criteria Met", "count": final_candidates, "total": funnel_total},
    ]
    funnel_svg = _make_funnel(
        funnel_stages,
        title="Screening Funnel",
        width=600,
        height=380,
    )
    html_parts.append("<h3>Screening Funnel</h3>")
    html_parts.append(f'<div class="chart-container" style="text-align:center;">{funnel_svg}</div>')

    # Ranked candidates table
    html_parts.append("<h3>Ranked Candidates</h3>")
    html_parts.append(
        "<table><tr><th>Rank</th><th>Molecule</th><th>Score</th>"
        "<th>MW</th><th>LogP</th><th>QED</th><th>Lipinski</th><th>All Criteria</th></tr>"
    )
    for rank, m in enumerate(scored, 1):
        lip_badge = ('<span class="badge badge-success">Pass</span>'
                     if m["lipinski_pass"]
                     else '<span class="badge badge-danger">Fail</span>')
        crit_badge = ('<span class="badge badge-success">Pass</span>'
                      if m["all_criteria_met"]
                      else '<span class="badge badge-danger">Fail</span>')
        # Highlight top 3
        row_style = ' style="background:#ecfeff;font-weight:600;"' if rank <= 3 else ""
        html_parts.append(
            f"<tr{row_style}><td>{rank}</td><td>{m['name']}</td>"
            f"<td>{m['screening_score']:.3f}</td>"
            f"<td>{m['mw']:.1f}</td><td>{m['logp']:.2f}</td>"
            f"<td>{m['qed']:.3f}</td><td>{lip_badge}</td><td>{crit_badge}</td></tr>"
        )
    html_parts.append("</table>")

    # Top 5 candidate cards with structures
    html_parts.append("<h3>Top 5 Candidates</h3>")
    html_parts.append('<div class="molecule-grid">')
    for m in scored[:5]:
        mol = Chem.MolFromSmiles(m["smiles"])
        img_uri = _mol_to_data_uri(mol, size=(250, 250)) if mol else ""
        badge_class = "badge-success" if m["all_criteria_met"] else "badge-info"
        badge_text = "All Criteria Met" if m["all_criteria_met"] else "Partial Match"
        html_parts.append(
            f'<div class="molecule-card" style="text-align:center;">'
            f'<img src="{img_uri}" style="width:140px;height:140px;object-fit:contain;"/>'
            f'<div style="font-weight:700;margin-top:6px;color:#0e4f6e;font-size:1.05em;">{m["name"]}</div>'
            f'<div style="font-size:0.85em;color:#155e75;margin:4px 0;">Score: {m["screening_score"]:.3f}</div>'
            f'<div style="font-size:0.8em;color:#6c757d;">MW: {m["mw"]:.1f} | LogP: {m["logp"]:.2f} | QED: {m["qed"]:.3f}</div>'
            f'<div style="margin-top:4px;"><span class="badge {badge_class}">{badge_text}</span></div>'
            f'</div>'
        )
    html_parts.append("</div>")

    # Tanimoto similarity heatmap
    if similarity_matrix:
        sim_heatmap = _make_heatmap(
            similarity_matrix, valid_names, valid_names,
            title="Pairwise Tanimoto Similarity (Morgan Fingerprints)",
            color_scale="cyan",
            width=700,
            height=max(500, len(valid_names) * 32 + 100),
        )
        html_parts.append("<h3>Chemical Similarity</h3>")
        html_parts.append(f'<div class="chart-container">{sim_heatmap}</div>')

    await flyte.report.replace.aio(
        _wrap_report("\n".join(html_parts)),
        do_flush=True,
    )

    output = {
        "ranked_molecules": scored,
        "similarity_matrix": similarity_matrix,
        "similarity_labels": valid_names,
        "funnel": funnel_stages,
        "target_profile": profile,
    }
    return json.dumps(output)

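# Illustrative helper (hypothetical, not called by the tutorial): the Tanimoto
# coefficient that DataStructs.TanimotoSimilarity computes on two fingerprint
# bit vectors, written out over sets of "on" bit indices:
#     T(A, B) = |A & B| / |A | B|
# A value of 1.0 means identical bit patterns, and 0.0 means no shared bits.
# Treating two empty fingerprints as identical is this sketch's own choice.
def _tanimoto_from_bits(a: set[int], b: set[int]) -> float:
    """Tanimoto (Jaccard) similarity of two sets of set-bit indices."""
    union = len(a | b)
    # Avoid division by zero when both fingerprints have no bits set.
    return len(a & b) / union if union else 1.0
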
def _parse_screening_json(screening_json: str) -> dict:
    """Parse screening JSON from screen_candidates, with safe defaults.

    The agent must pass the exact tool return value. Partial or hand-built JSON
    is tolerated for optional similarity fields only.
    """
    screening = json.loads(screening_json)
    if "ranked_molecules" not in screening:
        raise ValueError(
            "screening_json must be the exact JSON string returned by "
            "screen_candidates (missing 'ranked_molecules'). Do not construct, "
            "truncate, or summarize tool output."
        )
    screening.setdefault("similarity_matrix", [])
    screening.setdefault("similarity_labels", [])
    return screening

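# Illustrative helper (hypothetical, not called by the tutorial): the box plots
# in generate_report pick Q1/Q3 by list index, which is a coarse approximation
# for small libraries. This is a sketch of interpolated quartiles using the
# standard library's statistics.quantiles with method="inclusive", which treats
# the minimum and maximum as the 0th and 100th percentiles.
def _quartiles(values: list[float]) -> tuple[float, float, float]:
    """Return (Q1, median, Q3) using linear interpolation between data points."""
    import statistics

    data = sorted(values)
    if not data:
        raise ValueError("need at least one value")
    if len(data) == 1:
        # statistics.quantiles requires two or more points on older Python versions.
        return data[0], data[0], data[0]
    q1, median, q3 = statistics.quantiles(data, n=4, method="inclusive")
    return q1, median, q3
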
# ------------------------------------------------------------------
# Task 4: Generate final comprehensive report
# ------------------------------------------------------------------

@tool
@env.task(report=True)
async def generate_report(
    molecule_dir: flyte.io.Dir,
    properties_json: str,
    screening_json: str,
) -> str:
    """Generate a comprehensive drug screening report.

    Produces an executive summary, top candidate spotlight cards, property
    distributions, chemical diversity analysis, and final recommendation.

    Args:
        molecule_dir: Directory from load_molecules.
        properties_json: JSON from compute_properties.
        screening_json: Exact verbatim JSON string returned by screen_candidates
            (must include ranked_molecules, similarity_matrix, similarity_labels).
            Do not construct or summarize this payload yourself.

    Returns:
        JSON summary with total_screened, lipinski_passes, all_criteria_met,
        top_candidate, top_score, and top_3 ranked molecules.
    """
    from rdkit import Chem

    await flyte.report.replace.aio(
        _wrap_report("<h2>Generating Final Report...</h2>"),
        do_flush=True,
    )

    props = json.loads(properties_json)
    screening = _parse_screening_json(screening_json)
    ranked = screening["ranked_molecules"]
    sim_matrix = screening["similarity_matrix"]
    sim_labels = screening["similarity_labels"]

    total = props["total"]
    lipinski_pass = props["lipinski_pass_count"]
    all_criteria = sum(1 for m in ranked if m["all_criteria_met"])
    top = ranked[0] if ranked else None

    html_parts = []

    # --- Executive Summary ---
    html_parts.append("<h2>Drug Molecule Screening Report</h2>")
    top_name = top["name"] if top else "N/A"
    top_score = f'{top["screening_score"]:.3f}' if top else "N/A"
    html_parts.append(
        f'<div class="card">'
        f'<h3 style="margin-top:0;color:#0e4f6e;">Executive Summary</h3>'
        f'<p style="font-size:1.05em;">'
        f'<strong>{total}</strong> molecules were screened against the target drug profile. '
        f'<strong>{lipinski_pass}</strong> passed Lipinski\'s Rule of Five, and '
        f'<strong>{all_criteria}</strong> met all screening criteria. '
        f'The top candidate is <strong style="color:#0891b2;">{top_name}</strong> '
        f'with a screening score of <strong>{top_score}</strong>.</p>'
        f'</div>'
    )

    # Stat grid
    html_parts.append('<div class="stat-grid">')
    for val, label in [
        (str(total), "Molecules Screened"),
        (str(lipinski_pass), "Lipinski Passes"),
        (str(all_criteria), "All Criteria Met"),
        (top_score, "Top Score"),
        (f'{props["avg_mw"]:.0f} Da', "Avg. Molecular Weight"),
        (f'{props["avg_logp"]:.2f}', "Avg. LogP"),
    ]:
        html_parts.append(
            f'<div class="stat"><div class="value">{val}</div>'
            f'<div class="label">{label}</div></div>'
        )
    html_parts.append("</div>")

    # --- Top 3 Candidate Spotlights ---
    html_parts.append("<h2>Top Candidate Spotlights</h2>")

    for rank, m in enumerate(ranked[:3], 1):
        mol = Chem.MolFromSmiles(m["smiles"])
        img_uri = _mol_to_data_uri(mol, size=(300, 300)) if mol else ""

        # Gold, silver, and bronze accent colors plus a text rank label
        medal = ["gold", "silver", "#cd7f32"][rank - 1]
        rank_label = ["1st", "2nd", "3rd"][rank - 1]

        lip_badges = ""
        for rule, key in [("MW", "mw_ok"), ("LogP", "logp_ok"),
                          ("HBD", "hbd_ok"), ("HBA", "hba_ok")]:
            ok = m["lipinski"].get(key, False)
            cls = "badge-success" if ok else "badge-danger"
            lip_badges += f'<span class="badge {cls}" style="margin:2px;">{rule}</span> '

        html_parts.append(
            f'<div class="molecule-card" style="display:flex;gap:20px;align-items:flex-start;flex-wrap:wrap;">'
            f'<div style="text-align:center;min-width:180px;">'
            f'<div style="font-size:1.6em;font-weight:800;color:{medal};">{rank_label}</div>'
            f'<img src="{img_uri}" style="width:200px;height:200px;object-fit:contain;border-radius:8px;'
            f'border:2px solid #a5f3fc;"/>'
            f'<div style="font-weight:700;font-size:1.1em;color:#0e4f6e;margin-top:8px;">{m["name"]}</div>'
            f'</div>'
            f'<div style="flex:1;min-width:280px;">'
            f'<table style="margin:0;">'
            f'<tr><td><strong>SMILES</strong></td><td style="font-family:monospace;font-size:0.8em;word-break:break-all;">{m["smiles"]}</td></tr>'
            f'<tr><td><strong>Screening Score</strong></td><td style="font-weight:700;color:#0891b2;font-size:1.1em;">{m["screening_score"]:.3f}</td></tr>'
            f'<tr><td><strong>Molecular Weight</strong></td><td>{m["mw"]:.1f} Da</td></tr>'
            f'<tr><td><strong>LogP</strong></td><td>{m["logp"]:.2f}</td></tr>'
            f'<tr><td><strong>H-Bond Donors</strong></td><td>{m["hbd"]}</td></tr>'
            f'<tr><td><strong>H-Bond Acceptors</strong></td><td>{m["hba"]}</td></tr>'
            f'<tr><td><strong>TPSA</strong></td><td>{m["tpsa"]:.1f} &Aring;&sup2;</td></tr>'
            f'<tr><td><strong>Rotatable Bonds</strong></td><td>{m["rotatable_bonds"]}</td></tr>'
            f'<tr><td><strong>QED</strong></td><td>{m["qed"]:.4f}</td></tr>'
            f'<tr><td><strong>Lipinski Compliance</strong></td><td>{lip_badges}</td></tr>'
            f'</table>'
            f'</div>'
            f'</div>'
        )

    # --- Property Distributions (box plots: min, Q1, median, Q3, max) ---
    html_parts.append("<h2>Property Distributions</h2>")

    prop_keys = [("mw", "Molecular Weight (Da)"), ("logp", "LogP"),
                 ("tpsa", "TPSA"), ("qed", "QED Score")]
    for key, label in prop_keys:
        vals = sorted([m[key] for m in ranked])
        n = len(vals)
        if n == 0:
            continue
        v_min = vals[0]
        v_max = vals[-1]
        median = vals[n // 2] if n % 2 == 1 else (vals[n // 2 - 1] + vals[n // 2]) / 2
        # Index-based quartile approximation; adequate for a display-only box plot.
        q1 = vals[n // 4] if n >= 4 else v_min
        q3 = vals[3 * n // 4] if n >= 4 else v_max

        # Simple horizontal box-plot as SVG
        box_w = 500
        box_h = 50
        margin_l = 10
        v_range = v_max - v_min or 1

        def sx(v):
            return margin_l + ((v - v_min) / v_range) * (box_w - 2 * margin_l)

        box_svg = (
            f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {box_w} {box_h}" '
            f'style="width:100%;max-width:{box_w}px;height:auto;">'
            f'<rect width="{box_w}" height="{box_h}" fill="#fff" rx="4"/>'
            # Whisker line
            f'<line x1="{sx(v_min):.1f}" y1="25" x2="{sx(v_max):.1f}" y2="25" '
            f'stroke="#94a3b8" stroke-width="1.5"/>'
            # Min whisker
            f'<line x1="{sx(v_min):.1f}" y1="18" x2="{sx(v_min):.1f}" y2="32" '
            f'stroke="#94a3b8" stroke-width="1.5"/>'
            # Max whisker
            f'<line x1="{sx(v_max):.1f}" y1="18" x2="{sx(v_max):.1f}" y2="32" '
            f'stroke="#94a3b8" stroke-width="1.5"/>'
            # IQR box
            f'<rect x="{sx(q1):.1f}" y="14" width="{sx(q3) - sx(q1):.1f}" height="22" '
            f'fill="#a5f3fc" stroke="#0891b2" stroke-width="1.5" rx="3"/>'
            # Median line
            f'<line x1="{sx(median):.1f}" y1="12" x2="{sx(median):.1f}" y2="38" '
            f'stroke="#0e4f6e" stroke-width="2"/>'
            # Labels
            f'<text x="{sx(v_min):.1f}" y="46" text-anchor="middle" font-size="9" fill="#6c757d">{v_min:.1f}</text>'
            f'<text x="{sx(median):.1f}" y="10" text-anchor="middle" font-size="9" fill="#0e4f6e" font-weight="600">{median:.1f}</text>'
            f'<text x="{sx(v_max):.1f}" y="46" text-anchor="middle" font-size="9" fill="#6c757d">{v_max:.1f}</text>'
            f'</svg>'
        )
        html_parts.append(
            f'<div style="margin:8px 0;"><strong style="color:#155e75;">{label}</strong>'
            f'<div class="chart-container" style="padding:8px;">{box_svg}</div></div>'
        )

    # --- Chemical Diversity ---
    html_parts.append("<h2>Chemical Diversity Analysis</h2>")

    if sim_matrix and len(sim_matrix) > 1:
        # Compute average pairwise similarity (off-diagonal)
        n_mols = len(sim_matrix)
        off_diag = []
        for i in range(n_mols):
            for j in range(i + 1, n_mols):
                off_diag.append(sim_matrix[i][j])

        avg_sim = sum(off_diag) / len(off_diag) if off_diag else 0
        max_sim = max(off_diag) if off_diag else 0
        min_sim = min(off_diag) if off_diag else 0

        # Find most similar pair
        best_i, best_j = 0, 1
        best_val = 0
        for i in range(n_mols):
            for j in range(i + 1, n_mols):
                if sim_matrix[i][j] > best_val:
                    best_val = sim_matrix[i][j]
                    best_i, best_j = i, j

        html_parts.append('<div class="stat-grid">')
        html_parts.append(
            f'<div class="stat"><div class="value">{avg_sim:.3f}</div>'
            f'<div class="label">Avg. Pairwise Similarity</div></div>'
        )
        html_parts.append(
            f'<div class="stat"><div class="value">{min_sim:.3f}</div>'
            f'<div class="label">Min Similarity</div></div>'
        )
        html_parts.append(
            f'<div class="stat"><div class="value">{max_sim:.3f}</div>'
            f'<div class="label">Max Similarity</div></div>'
        )
        html_parts.append("</div>")

        diversity_text = "highly diverse" if avg_sim < 0.3 else "moderately diverse" if avg_sim < 0.5 else "relatively similar"
        html_parts.append(
            f'<div class="note">'
            f'The library is <strong>{diversity_text}</strong> (avg. Tanimoto = {avg_sim:.3f}). '
            f'The most similar pair is <strong>{sim_labels[best_i]}</strong> and '
            f'<strong>{sim_labels[best_j]}</strong> (similarity = {best_val:.3f}).</div>'
        )

    # --- Recommendation ---
    html_parts.append("<h2>Recommendation</h2>")
    if top:
        html_parts.append(
            f'<div class="card">'
            f'<h3 style="margin-top:0;color:#0891b2;">Top Candidate: {top["name"]}</h3>'
            f'<p>Based on the virtual screening analysis, <strong>{top["name"]}</strong> '
            f'achieved the highest composite screening score of <strong>{top["screening_score"]:.3f}</strong>. '
        )

        reasons = []
        if top["lipinski_pass"]:
            reasons.append("full Lipinski Rule of Five compliance")
        if top["qed"] > 0.5:
            reasons.append(f"high drug-likeness (QED = {top['qed']:.3f})")
        if top.get("all_criteria_met"):
            reasons.append("all target profile criteria met")
        if top["mw"] <= 500:
            reasons.append(f"favorable molecular weight ({top['mw']:.1f} Da)")

        if reasons:
            html_parts.append(
                f'This candidate stands out due to: {", ".join(reasons)}.</p>'
            )
        else:
            html_parts.append("</p>")

        # Runner-up mentions
        if len(ranked) >= 2:
            html_parts.append(
                f'<p style="font-size:0.9em;color:#6c757d;">Runner-up candidates: '
            )
            runners = []
            for m in ranked[1:4]:
                runners.append(f'{m["name"]} (score: {m["screening_score"]:.3f})')
            html_parts.append(", ".join(runners) + ".</p>")

        html_parts.append("</div>")

    # Final note
    html_parts.append(
        '<div class="note">'
        "This is a virtual screening analysis. All candidates should undergo "
        "further computational validation (molecular dynamics, docking) and "
        "experimental testing before advancing to clinical trials.</div>"
    )

    await flyte.report.replace.aio(
        _wrap_report("\n".join(html_parts)),
        do_flush=True,
    )

    # JSON summary
    summary = {
        "total_screened": total,
        "lipinski_passes": lipinski_pass,
        "all_criteria_met": all_criteria,
        "top_candidate": top["name"] if top else None,
        "top_score": top["screening_score"] if top else None,
        "top_3": [
            {"name": m["name"], "score": m["screening_score"]}
            for m in ranked[:3]
        ],
    }
    return json.dumps(summary)

# ------------------------------------------------------------------
# Agent
# ------------------------------------------------------------------

# {{docs-fragment agent}}
SCREENING_AGENT_INSTRUCTIONS = """\
You are a medicinal chemistry screening strategist. You orchestrate a virtual \
screening pipeline using durable Flyte tools. You NEVER invent molecular \
properties — only RDKit tools compute them.

Workflow:
1. If target_profile is not provided in the user message, derive a JSON \
target_profile from the therapeutic brief. Valid keys: mw, logp, hbd, hba, tpsa \
(each [min, max]). Ground choices in oral bioavailability / kinase / CNS rules \
as appropriate to the brief.
2. First pass (always): load_molecules → compute_properties → \
screen_candidates → generate_report. Pass tool outputs between steps exactly \
(molecule_dir from load_molecules into compute_properties and generate_report; \
properties_json from compute_properties into screen_candidates and \
generate_report; screening_json must be the complete, unmodified string \
returned by screen_candidates — never rebuild or summarize JSON yourself).
3. Read the JSON summary returned by generate_report. Reflect:
   - If all_criteria_met == 0: relax exactly ONE profile bound by ~10–20% \
and re-run screen_candidates then generate_report only, reusing the same \
molecule_dir and properties_json from the first pass.
   - If all molecules pass but diversity is a stated goal: note high similarity \
in your summary; do not re-run unless brief asks for stricter filters.
   - Maximum ONE rescreen iteration.
4. Finish with plain text: top candidate, rationale tied to computed metrics \
from the tool JSON, funnel interpretation, and suggested next steps (docking, \
ADMET lab tests).

If the user supplies an explicit target_profile JSON, use it as-is.

Do NOT ask the user for SMILES or molecule lists when molecules_json is empty — \
the default library is loaded automatically.
"""

screening_agent = Agent(
    name="drug-screening-agent",
    instructions=SCREENING_AGENT_INSTRUCTIONS,
    model=MODEL,
    tools=[
        load_molecules,
        compute_properties,
        screen_candidates,
        generate_report,
    ],
    max_turns=12,
)
# {{/docs-fragment agent}}

# ------------------------------------------------------------------
# Pipeline
# ------------------------------------------------------------------

# {{docs-fragment pipeline}}
@env.task(report=True)
async def pipeline(
    brief: str = "Screen the default drug library for orally bioavailable small molecules.",
    molecules_json: str = "",
    target_profile: str = "",
) -> str:
    """Agentic virtual drug molecule screening pipeline.

    A medicinal-chemistry agent interprets the screening brief, derives or
    applies a target profile, orchestrates the RDKit screening stages, and
    optionally re-screens when funnel results are too narrow.

    Args:
        brief: Natural-language therapeutic goal (e.g. oral kinase inhibitors,
            CNS-penetrant small molecules).
        molecules_json: JSON mapping molecule names to SMILES strings.
            Defaults to a curated library of ~15 well-known drugs.
        target_profile: Optional JSON with desired property ranges that
            overrides agent-derived criteria
            (e.g. {"mw": [150, 500], "logp": [-0.5, 5]}).

    Returns:
        Agent summary with screening rationale and key results.
    """
    prompt_parts = [
        f"Screening brief: {brief}",
        'Use molecules_json="" for the built-in default library unless provided below.',
        "Compose the four stage tools in order: load_molecules → compute_properties "
        "→ screen_candidates → generate_report. Pass each tool's full return value "
        "verbatim to the next step (especially screening_json). Re-run "
        "screen_candidates and generate_report at most once if the funnel is too narrow.",
    ]
    if molecules_json.strip():
        prompt_parts.append(f"molecules_json: {molecules_json}")
    if target_profile.strip():
        prompt_parts.append(f"Use this target_profile exactly: {target_profile}")

    result = await screening_agent.run.aio("\n".join(prompt_parts))
    return result.summary or result.error or ""

# {{/docs-fragment pipeline}}

# ------------------------------------------------------------------
# Rescreen demo — tight profile + explicit rescreen instructions
# ------------------------------------------------------------------

# Initial profile is deliberately strict (narrow MW + low LogP cap) so
# all_criteria_met is typically 0 on the default library; the brief then
# forces a single rescreen with a widened LogP window.
RESCREEN_DEMO_TARGET_PROFILE = (
    '{"mw": [150, 200], "logp": [-0.5, 1.0], "hbd": [0, 1], '
    '"hba": [0, 3], "tpsa": [20, 45]}'
)
RESCREEN_DEMO_TARGET_PROFILE_RESCREEN = (
    '{"mw": [150, 200], "logp": [-0.5, 3.5], "hbd": [0, 1], '
    '"hba": [0, 3], "tpsa": [20, 45]}'
)
RESCREEN_DEMO_BRIEF = f"""\
Two-round agentic screening demo on the default library.

**Round 1 (strict profile):** load_molecules → compute_properties → \
screen_candidates → generate_report using the initial target_profile exactly.

**Round 2 (required — do not skip):** call screen_candidates then generate_report \
again, reusing the same molecule_dir and properties_json from round 1, with this \
relaxed target_profile (wider LogP window only): \
{RESCREEN_DEMO_TARGET_PROFILE_RESCREEN}

Pass every tool return value verbatim to the next step. After both rounds, \
summarize how the funnel and top candidates changed between round 1 and round 2."""

# {{docs-fragment rescreen_demo}}
@env.task(report=True)
async def rescreen_demo() -> str:
    """Example run with a two-round execution graph (rescreen).

    Round 1 uses a strict CNS-like profile; round 2 always re-runs
    screen_candidates and generate_report with a widened LogP window,
    reusing cached molecule_dir and properties_json.
    """
    return await pipeline(
        brief=RESCREEN_DEMO_BRIEF,
        target_profile=RESCREEN_DEMO_TARGET_PROFILE,
    )

# {{/docs-fragment rescreen_demo}}

# {{docs-fragment main}}
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(pipeline)
    print(run.url)
    run.wait()
# {{/docs-fragment main}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/drug_molecule_screening/drug_molecule_screening.py*
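The diversity section of `generate_report` averages the unique off-diagonal (upper-triangle) entries of a Tanimoto similarity matrix. It then labels the library as highly diverse (average < 0.3), moderately diverse (< 0.5), or relatively similar, and reports the most similar pair. The excerpt above does not show how `sim_matrix` is built. The sketch below therefore makes an assumption: each fingerprint is a plain Python set of on-bit indices, not an RDKit fingerprint object. The function names `tanimoto` and `diversity_summary` are hypothetical and are not part of the tutorial.

```python
from itertools import combinations


def tanimoto(a: set[int], b: set[int]) -> float:
    """Tanimoto (Jaccard) similarity between two sets of on-bit indices."""
    union = a | b
    return len(a & b) / len(union) if union else 0.0


def diversity_summary(fps: dict[str, set[int]]) -> dict:
    """Summarize pairwise similarity like the report's diversity section (sketch)."""
    names = list(fps)
    # Only unique off-diagonal pairs (i < j), mirroring the upper-triangle loop.
    pairs = [(a, b, tanimoto(fps[a], fps[b])) for a, b in combinations(names, 2)]
    if not pairs:
        return {"avg": 0.0, "label": "n/a", "most_similar": None}
    sims = [s for _, _, s in pairs]
    avg = sum(sims) / len(sims)
    # Same thresholds as the report's diversity_text expression.
    label = (
        "highly diverse" if avg < 0.3
        else "moderately diverse" if avg < 0.5
        else "relatively similar"
    )
    best = max(pairs, key=lambda p: p[2])  # first pair wins on ties
    return {
        "avg": avg,
        "min": min(sims),
        "max": max(sims),
        "label": label,
        "most_similar": best,
    }


# Toy fingerprints, for illustration only.
fps = {"A": {1, 2, 3, 4}, "B": {1, 2, 3, 5}, "C": {7, 8, 9}}
print(diversity_summary(fps))
```

With real data, you would replace the toy sets with fingerprints computed from the SMILES strings. The thresholding and pair selection stay the same.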

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.5.4",
#    "litellm",
#    "rdkit",
#    "numpy",
#    "scikit-learn",
#    "pillow",
# ]
# ///
```
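The agent instructions tell the LLM to relax exactly one target-profile bound by roughly 10–20% when no molecule meets every criterion. The rescreen demo works differently. It supplies its second profile explicitly and widens the LogP upper bound from 1.0 to 3.5, which is far more than 20%. If you want a deterministic baseline to compare with the agent's choice, here is a hypothetical helper, `relax_one_bound`, which does not exist in the tutorial. It widens a single `[min, max]` range by a fraction of that bound's magnitude, or by a fraction of the range width when the bound is 0. It returns a JSON string in the same format as `RESCREEN_DEMO_TARGET_PROFILE`.

```python
import json


def relax_one_bound(profile_json: str, key: str, side: str = "max",
                    fraction: float = 0.15) -> str:
    """Widen exactly one bound of a target_profile (hypothetical helper).

    profile_json: JSON mapping property names to [min, max] ranges.
    key: property to relax, e.g. "logp".
    side: "max" raises the upper bound; "min" lowers the lower bound.
    """
    profile = json.loads(profile_json)
    lo, hi = profile[key]
    if side == "max":
        # Scale by the bound's magnitude; fall back to the range width if it is 0.
        step = abs(hi) * fraction or (hi - lo) * fraction
        profile[key] = [lo, round(hi + step, 3)]
    elif side == "min":
        step = abs(lo) * fraction or (hi - lo) * fraction
        profile[key] = [round(lo - step, 3), hi]
    else:
        raise ValueError("side must be 'max' or 'min'")
    return json.dumps(profile)


strict = ('{"mw": [150, 200], "logp": [-0.5, 1.0], "hbd": [0, 1], '
          '"hba": [0, 3], "tpsa": [20, 45]}')
print(relax_one_bound(strict, "logp", side="max", fraction=0.2))
```

Only the chosen range changes, so the rest of the profile passes through untouched.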

## Define the screening agent

The agent receives a natural-language brief and composes the four stage tools in order. Each tool is a durable Flyte task declared with `report=True`, so each one renders its own report in the Flyte UI.
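The hand-off the agent is asked to perform is easier to follow as ordinary control flow. The sketch below makes two simplifications. It replaces the four Flyte tasks with hypothetical plain-Python stand-ins, whose signatures are inferred from the agent instructions rather than copied from the tutorial. It also hard-codes the sequence the LLM is told to follow: one full pass, then at most one rescreen that reuses `molecule_dir` and `properties_json`. As a toy stand-in for the RDKit descriptors, the "property" is the length of each SMILES string.

```python
import json
import os
import tempfile

# Hypothetical stand-ins for the four tool tasks (not the tutorial's code).


def load_molecules(molecules_json: str) -> str:
    """Write the name -> SMILES mapping to a temp directory; return its path."""
    molecule_dir = tempfile.mkdtemp(prefix="molecules-")
    with open(os.path.join(molecule_dir, "molecules.json"), "w") as f:
        json.dump(json.loads(molecules_json), f)
    return molecule_dir


def compute_properties(molecule_dir: str) -> str:
    """Toy descriptor: SMILES length (not a real chemical property)."""
    with open(os.path.join(molecule_dir, "molecules.json")) as f:
        mols = json.load(f)
    return json.dumps({name: {"length": len(smi)} for name, smi in mols.items()})


def screen_candidates(properties_json: str, target_profile: str) -> str:
    """Keep molecules whose properties fall inside every [min, max] range."""
    props = json.loads(properties_json)
    profile = json.loads(target_profile)
    passed = [
        name for name, p in props.items()
        if all(lo <= p[k] <= hi for k, (lo, hi) in profile.items())
    ]
    return json.dumps({"passed": passed, "profile": profile})


def generate_report(molecule_dir: str, properties_json: str, screening_json: str) -> str:
    # molecule_dir is unused here; it is kept to mirror the described data flow.
    return json.dumps({
        "total_screened": len(json.loads(properties_json)),
        "all_criteria_met": len(json.loads(screening_json)["passed"]),
    })


def run_pipeline(molecules_json: str, target_profile: str, relaxed_profile: str) -> list[dict]:
    # First pass: all four stages, each return value passed along verbatim.
    molecule_dir = load_molecules(molecules_json)
    properties_json = compute_properties(molecule_dir)
    screening_json = screen_candidates(properties_json, target_profile)
    summaries = [json.loads(generate_report(molecule_dir, properties_json, screening_json))]
    # Reflection: at most one rescreen, reusing molecule_dir and properties_json.
    if summaries[0]["all_criteria_met"] == 0:
        screening_json = screen_candidates(properties_json, relaxed_profile)
        summaries.append(
            json.loads(generate_report(molecule_dir, properties_json, screening_json))
        )
    return summaries


mols = json.dumps({"Aspirin": "CC(=O)OC1=CC=CC=C1C(=O)O", "Metformin": "CN(C)C(=N)NC(=N)N"})
print(run_pipeline(mols, '{"length": [0, 10]}', '{"length": [0, 20]}'))
```

In the tutorial, the LLM makes the rescreen decision and chooses the relaxed profile itself. Because the tools are durable Flyte tasks, repeated calls can reuse earlier outputs instead of recomputing them.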

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.5.4",
#    "litellm",
#    "rdkit",
#    "numpy",
#    "scikit-learn",
#    "pillow",
# ]
# main = "pipeline"
# params = ""
# ///
"""Virtual drug molecule screening — compute properties, apply Lipinski filters, rank candidates."""

import base64
import io
import json
import logging
import math
import os
import tempfile

import flyte
import flyte.io
import flyte.report
from flyte.ai.agents import Agent, tool

MODEL = os.getenv("DRUG_SCREENING_MODEL", "claude-haiku-4-5")

# {{docs-fragment env}}
main_img = flyte.Image.from_uv_script(__file__, name="drug-molecule-screening", pre=True).with_apt_packages(
    "libxrender1", "libxext6", "libexpat1",
)

env = flyte.TaskEnvironment(
    name="drug-molecule-screening",
    image=main_img,
    resources=flyte.Resources(cpu=2, memory="6Gi"),
    secrets=[
        flyte.Secret(key="internal-anthropic-api-key", as_env_var="ANTHROPIC_API_KEY"),
    ],
)
# {{/docs-fragment env}}

logging.basicConfig(level=logging.WARNING, format="%(message)s", force=True)
log = logging.getLogger(__name__)
log.setLevel(logging.INFO)

# ------------------------------------------------------------------
# Default molecule library — real SMILES for well-known drugs
# ------------------------------------------------------------------

DEFAULT_MOLECULES = {
    "Aspirin": "CC(=O)OC1=CC=CC=C1C(=O)O",
    "Ibuprofen": "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O",
    "Caffeine": "CN1C=NC2=C1C(=O)N(C(=O)N2C)C",
    "Penicillin G": "CC1(C(N2C(S1)C(C2=O)NC(=O)CC3=CC=CC=C3)C(=O)O)C",
    "Metformin": "CN(C)C(=N)NC(=N)N",
    "Paracetamol": "CC(=O)NC1=CC=C(C=C1)O",
    "Diazepam": "ClC1=CC2=C(C=C1)N(C(=O)CN=C2C3=CC=CC=C3)C",
    "Omeprazole": "CC1=CN=C(C(=C1OC)C)CS(=O)C2=NC3=CC=CC=C3N2",
    "Atorvastatin": "CC(C)C1=C(C(=C(N1CCC(CC(CC(=O)O)O)O)C2=CC=C(C=C2)F)C3=CC=CC=C3)C(=O)NC4=CC=CC=C4",
    "Methotrexate": "CN(CC1=CN=C2N=C(N=C(N)C2=N1)N)C3=CC=C(C=C3)C(=O)NC(CCC(=O)O)C(=O)O",
    "Doxorubicin": "CC1C(C(CC(O1)OC2CC(CC3=C2C(=C4C(=C3O)C(=O)C5=C(C4=O)C(=CC=C5)OC)O)(C(=O)CO)O)N)O",
    "Tamoxifen": "CCC(=C(C1=CC=CC=C1)C2=CC=C(C=C2)OCCN(C)C)C3=CC=CC=C3",
    "Lopinavir": "CC1=C(C(=CC=C1)C)OCC(=O)NC(CC2=CC=CC=C2)C(CC(CC3=CC=CC=C3)NC(=O)C(C(C)C)N4CCCNC4=O)O",
    "Remdesivir": "CCC(CC)COC(=O)C(C)NP(=O)(OCC1C(C(C(O1)C2=CC=C3N2N=CN=C3N)O)O)OC4=CC=CC=C4",
    "Erlotinib": "COCCOC1=CC2=C(C=C1OCCOC)C(=NC=N2)NC3=CC=CC(=C3)C#C",
}

# ------------------------------------------------------------------
# Report styling — pharma blue/cyan theme
# ------------------------------------------------------------------

REPORT_CSS = """
<style>
  .report { font-family: system-ui, -apple-system, sans-serif; max-width: 960px; margin: 0 auto; color: #1a1a2e; }
  .report h2 { color: #0e4f6e; border-bottom: 2px solid #0891b2; padding-bottom: 8px; margin-top: 24px; }
  .report h3 { color: #155e75; margin-top: 20px; }
  .report .card { background: #ecfeff; border: 1px solid #a5f3fc; border-radius: 8px; padding: 16px; margin: 12px 0; }
  .report .stat-grid { display: grid; grid-template-columns: repeat(auto-fit, minmax(160px, 1fr)); gap: 12px; margin: 12px 0; }
  .report .stat { background: #fff; border: 1px solid #cffafe; border-radius: 6px; padding: 12px; text-align: center; }
  .report .stat .value { font-size: 1.5em; font-weight: 700; color: #0e4f6e; }
  .report .stat .label { font-size: 0.85em; color: #6c757d; margin-top: 4px; }
  .report table { border-collapse: collapse; width: 100%; margin: 12px 0; }
  .report th { background: #0e4f6e; color: #fff; padding: 10px 14px; text-align: left; font-weight: 600; }
  .report td { padding: 8px 14px; border-bottom: 1px solid #cffafe; }
  .report tr:nth-child(even) { background: #ecfeff; }
  .report .badge { display: inline-block; padding: 2px 8px; border-radius: 12px; font-size: 0.8em; font-weight: 600; }
  .report .badge-success { background: #d1fae5; color: #065f46; }
  .report .badge-danger { background: #fee2e2; color: #991b1b; }
  .report .badge-info { background: #cffafe; color: #155e75; }
  .report .chart-container { background: #fff; border: 1px solid #cffafe; border-radius: 8px; padding: 16px; margin: 16px 0; }
  .report .note { background: #ecfeff; border-left: 4px solid #0891b2; padding: 10px 14px; border-radius: 4px; margin: 12px 0; font-size: 0.9em; }
  .report .molecule-card { background: #fff; border: 1px solid #cffafe; border-radius: 8px; padding: 16px; margin: 12px 0; }
  .report .molecule-grid { display: grid; grid-template-columns: repeat(auto-fill, minmax(200px, 1fr)); gap: 12px; margin: 16px 0; }
  .report .funnel { text-align: center; margin: 24px 0; }
</style>
"""

def _wrap_report(html: str) -> str:
    """Wrap HTML content with report styling."""
    return f'{REPORT_CSS}<div class="report">{html}</div>'

# ------------------------------------------------------------------
# Chart and image helpers (SVG charts, PNG structure images)
# ------------------------------------------------------------------

def _mol_to_data_uri(mol, size: tuple[int, int] = (300, 300)) -> str:
    """Convert an RDKit molecule to a PNG base64 data URI."""
    from rdkit.Chem import Draw

    img = Draw.MolToImage(mol, size=size)
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    b64 = base64.b64encode(buf.getvalue()).decode()
    return f"data:image/png;base64,{b64}"

def _make_bar_chart(
    labels: list[str],
    series: dict[str, list[float]],
    title: str = "",
    colors: list[str] | None = None,
    width: int = 700,
    height: int = 340,
    y_max_cap: float | None = None,
    horizontal: bool = False,
    value_fmt: str = ".1f",
) -> str:
    """Generate an SVG grouped bar chart.

    Args:
        labels: Category labels.
        series: Dict mapping series name to list of values.
        title: Chart title.
        colors: Colors for each series.
        width/height: SVG dimensions.
        y_max_cap: Cap the y-axis at this value.
        horizontal: If True, draw horizontal bars.
        value_fmt: Format string for value labels.

    Returns:
        SVG string.
    """
    if not labels:
        return ""

    default_colors = ["#0891b2", "#0e4f6e", "#06d6a0", "#a5f3fc", "#155e75"]
    colors = colors or default_colors

    if horizontal:
        return _make_horizontal_bar_chart(labels, series, title, colors, width, height, value_fmt)

    ml, mr, mt, mb = 60, 20, 40, 60
    cw = width - ml - mr
    ch = height - mt - mb

    all_vals = [v for vals in series.values() for v in vals]
    y_max = max(all_vals) if all_vals else 1
    y_max_plot = y_max * 1.15 or 1
    if y_max_cap is not None:
        y_max_plot = min(y_max_plot, y_max_cap) or y_max_cap

    n_groups = len(labels)
    n_series = len(series)
    group_width = cw / n_groups
    bar_width = group_width * 0.7 / max(n_series, 1)
    gap = group_width * 0.15

    def sy(v):
        return mt + ch - (v / y_max_plot) * ch

    svg = [
        f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {height}" '
        f'style="width:100%;max-width:{width}px;height:auto;">',
        f'<rect width="{width}" height="{height}" fill="#fff" rx="6"/>',
    ]

    # Grid lines
    for i in range(6):
        y_tick = y_max_plot * i / 5
        py = sy(y_tick)
        svg.append(
            f'<line x1="{ml}" y1="{py:.1f}" x2="{ml + cw}" y2="{py:.1f}" '
            f'stroke="#e0f2fe" stroke-width="1"/>'
        )
        svg.append(
            f'<text x="{ml - 8}" y="{py + 4:.1f}" text-anchor="end" '
            f'font-size="11" fill="#6c757d">{y_tick:{value_fmt}}</text>'
        )

    # Axes
    svg.append(
        f'<line x1="{ml}" y1="{mt}" x2="{ml}" y2="{mt + ch}" '
        f'stroke="#94a3b8" stroke-width="1.5"/>'
    )
    svg.append(
        f'<line x1="{ml}" y1="{mt + ch}" x2="{ml + cw}" y2="{mt + ch}" '
        f'stroke="#94a3b8" stroke-width="1.5"/>'
    )

    # Bars
    for gi, label in enumerate(labels):
        gx = ml + gi * group_width + gap
        for si, (name, vals) in enumerate(series.items()):
            color = colors[si % len(colors)]
            bx = gx + si * bar_width
            val = vals[gi]
            by = sy(val)
            bh = mt + ch - by
            svg.append(
                f'<rect x="{bx:.1f}" y="{by:.1f}" width="{bar_width - 1:.1f}" '
                f'height="{bh:.1f}" fill="{color}" rx="2"/>'
            )
            svg.append(
                f'<text x="{bx + bar_width / 2:.1f}" y="{by - 4:.1f}" '
                f'text-anchor="middle" font-size="9" fill="#1a1a2e">'
                f'{val:{value_fmt}}</text>'
            )
        # Truncate long labels
        disp_label = label if len(label) <= 12 else label[:10] + ".."
        svg.append(
            f'<text x="{gx + n_series * bar_width / 2:.1f}" y="{mt + ch + 16}" '
            f'text-anchor="middle" font-size="10" fill="#6c757d" '
            f'transform="rotate(-35, {gx + n_series * bar_width / 2:.1f}, {mt + ch + 16})">'
            f'{disp_label}</text>'
        )

    # Title
    if title:
        svg.append(
            f'<text x="{width / 2}" y="22" text-anchor="middle" '
            f'font-size="14" font-weight="600" fill="#0e4f6e">{title}</text>'
        )

    # Legend
    if n_series > 1:
        lx = ml + cw - len(series) * 100
        for si, name in enumerate(series):
            color = colors[si % len(colors)]
            svg.append(
                f'<rect x="{lx + si * 100}" y="{mt + ch + 40}" width="12" '
                f'height="12" rx="2" fill="{color}"/>'
            )
            svg.append(
                f'<text x="{lx + si * 100 + 16}" y="{mt + ch + 51}" font-size="11" '
                f'fill="#1a1a2e">{name}</text>'
            )

    svg.append("</svg>")
    return "\n".join(svg)

def _make_horizontal_bar_chart(
    labels: list[str],
    series: dict[str, list[float]],
    title: str = "",
    colors: list[str] | None = None,
    width: int = 700,
    height: int = 400,
    value_fmt: str = ".1f",
) -> str:
    """Generate an SVG horizontal bar chart (sorted)."""
    default_colors = ["#0891b2", "#0e4f6e", "#06d6a0"]
    colors = colors or default_colors

    n = len(labels)
    row_height = max(22, min(35, (height - 80) // max(n, 1)))
    actual_height = max(height, 80 + n * row_height)
    ml, mr, mt, mb = 120, 60, 40, 20
    cw = width - ml - mr
    ch = actual_height - mt - mb

    # Use first series
    first_key = list(series.keys())[0]
    vals = series[first_key]
    x_max = max(vals) * 1.15 if vals else 1

    svg = [
        f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {actual_height}" '
        f'style="width:100%;max-width:{width}px;height:auto;">',
        f'<rect width="{width}" height="{actual_height}" fill="#fff" rx="6"/>',
    ]

    if title:
        svg.append(
            f'<text x="{width / 2}" y="22" text-anchor="middle" '
            f'font-size="14" font-weight="600" fill="#0e4f6e">{title}</text>'
        )

    bar_h = row_height * 0.65
    for i, (label, val) in enumerate(zip(labels, vals)):
        y = mt + i * row_height
        bw = (val / x_max) * cw if x_max else 0
        color = colors[i % len(colors)]
        # Label
        disp = label if len(label) <= 14 else label[:12] + ".."
        svg.append(
            f'<text x="{ml - 8}" y="{y + bar_h / 2 + 4:.1f}" text-anchor="end" '
            f'font-size="11" fill="#1a1a2e">{disp}</text>'
        )
        # Bar
        svg.append(
            f'<rect x="{ml}" y="{y:.1f}" width="{bw:.1f}" height="{bar_h:.1f}" '
            f'fill="{color}" rx="3"/>'
        )
        # Value
        svg.append(
            f'<text x="{ml + bw + 6:.1f}" y="{y + bar_h / 2 + 4:.1f}" '
            f'font-size="11" fill="#0e4f6e" font-weight="600">{val:{value_fmt}}</text>'
        )

    svg.append("</svg>")
    return "\n".join(svg)

def _make_heatmap(
    matrix: list[list[float]],
    row_labels: list[str],
    col_labels: list[str],
    title: str = "",
    color_scale: str = "cyan",
    width: int = 700,
    height: int = 500,
    value_fmt: str = ".2f",
) -> str:
    """Generate an SVG heatmap.

    Args:
        matrix: 2D list of values (rows x cols).
        row_labels: Labels for rows.
        col_labels: Labels for columns.
        title: Chart title.
        color_scale: Color scheme ("cyan", "red", "green").
        width/height: SVG dimensions.
        value_fmt: Format string for cell values.

    Returns:
        SVG string.
    """
    if not matrix or not matrix[0]:
        return ""

    n_rows = len(matrix)
    n_cols = len(matrix[0])

    ml, mr, mt, mb = 110, 20, 70, 20
    cw = width - ml - mr
    ch = height - mt - mb
    cell_w = cw / n_cols
    cell_h = ch / n_rows

    # Flatten to find range
    flat = [v for row in matrix for v in row]
    v_min = min(flat)
    v_max = max(flat)
    v_range = v_max - v_min or 1

    def color_for(v):
        t = (v - v_min) / v_range
        if color_scale == "cyan":
            # White to deep teal
            r = int(255 - t * (255 - 14))
            g = int(255 - t * (255 - 79))
            b = int(255 - t * (255 - 110))
        elif color_scale == "red":
            r = int(255 - t * 50)
            g = int(255 - t * 200)
            b = int(255 - t * 200)
        else:  # green
            r = int(255 - t * 200)
            g = int(255 - t * 50)
            b = int(255 - t * 200)
        return f"rgb({r},{g},{b})"

    svg = [
        f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {height}" '
        f'style="width:100%;max-width:{width}px;height:auto;">',
        f'<rect width="{width}" height="{height}" fill="#fff" rx="6"/>',
    ]

    if title:
        svg.append(
            f'<text x="{width / 2}" y="22" text-anchor="middle" '
            f'font-size="14" font-weight="600" fill="#0e4f6e">{title}</text>'
        )

    # Column labels (rotated)
    for ci, label in enumerate(col_labels):
        x = ml + ci * cell_w + cell_w / 2
        disp = label if len(label) <= 12 else label[:10] + ".."
        svg.append(
            f'<text x="{x:.1f}" y="{mt - 8}" text-anchor="end" font-size="10" '
            f'fill="#1a1a2e" transform="rotate(-45, {x:.1f}, {mt - 8})">{disp}</text>'
        )

    # Row labels + cells
    for ri, (row_label, row_vals) in enumerate(zip(row_labels, matrix)):
        y = mt + ri * cell_h
        disp = row_label if len(row_label) <= 14 else row_label[:12] + ".."
        svg.append(
            f'<text x="{ml - 8}" y="{y + cell_h / 2 + 4:.1f}" text-anchor="end" '
            f'font-size="10" fill="#1a1a2e">{disp}</text>'
        )
        for ci, val in enumerate(row_vals):
            x = ml + ci * cell_w
            fill = color_for(val)
            svg.append(
                f'<rect x="{x:.1f}" y="{y:.1f}" width="{cell_w:.1f}" '
                f'height="{cell_h:.1f}" fill="{fill}" stroke="#fff" stroke-width="1"/>'
            )
            # Text color: dark on light, light on dark
            t = (val - v_min) / v_range
            txt_color = "#fff" if t > 0.55 else "#1a1a2e"
            # Only show text if cells are large enough
            if cell_w > 30 and cell_h > 18:
                svg.append(
                    f'<text x="{x + cell_w / 2:.1f}" y="{y + cell_h / 2 + 4:.1f}" '
                    f'text-anchor="middle" font-size="9" fill="{txt_color}">'
                    f'{val:{value_fmt}}</text>'
                )

    svg.append("</svg>")
    return "\n".join(svg)

def _make_scatter_plot(
    points: list[dict],
    x_label: str = "MW",
    y_label: str = "LogP",
    title: str = "",
    reference_lines: list[dict] | None = None,
    width: int = 700,
    height: int = 400,
) -> str:
    """Generate an SVG scatter plot.

    Args:
        points: List of dicts with "x", "y", "label" keys.
        x_label/y_label: Axis labels.
        title: Chart title.
        reference_lines: List of dicts with "axis" ("x"/"y"), "value", "label".
        width/height: SVG dimensions.

    Returns:
        SVG string.
    """
    if not points:
        return ""

    ml, mr, mt, mb = 60, 30, 40, 50
    cw = width - ml - mr
    ch = height - mt - mb

    x_vals = [p["x"] for p in points]
    y_vals = [p["y"] for p in points]
    x_min, x_max = min(x_vals) * 0.9, max(x_vals) * 1.1
    y_min, y_max = min(y_vals) - 1, max(y_vals) + 1

    # Extend ranges to include reference lines
    if reference_lines:
        for rl in reference_lines:
            if rl["axis"] == "x":
                x_max = max(x_max, rl["value"] * 1.1)
            else:
                y_max = max(y_max, rl["value"] * 1.1)

    x_range = x_max - x_min or 1
    y_range = y_max - y_min or 1

    def sx(v):
        return ml + (v - x_min) / x_range * cw

    def sy(v):
        return mt + ch - (v - y_min) / y_range * ch

    svg = [
        f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {height}" '
        f'style="width:100%;max-width:{width}px;height:auto;">',
        f'<rect width="{width}" height="{height}" fill="#fff" rx="6"/>',
    ]

    # Grid
    for i in range(6):
        y_tick = y_min + y_range * i / 5
        py = sy(y_tick)
        svg.append(
            f'<line x1="{ml}" y1="{py:.1f}" x2="{ml + cw}" y2="{py:.1f}" '
            f'stroke="#e0f2fe" stroke-width="1"/>'
        )
        svg.append(
            f'<text x="{ml - 8}" y="{py + 4:.1f}" text-anchor="end" '
            f'font-size="11" fill="#6c757d">{y_tick:.1f}</text>'
        )

    for i in range(6):
        x_tick = x_min + x_range * i / 5
        px = sx(x_tick)
        svg.append(
            f'<text x="{px:.1f}" y="{mt + ch + 20}" text-anchor="middle" '
            f'font-size="11" fill="#6c757d">{x_tick:.0f}</text>'
        )

    # Axes
    svg.append(
        f'<line x1="{ml}" y1="{mt}" x2="{ml}" y2="{mt + ch}" '
        f'stroke="#94a3b8" stroke-width="1.5"/>'
    )
    svg.append(
        f'<line x1="{ml}" y1="{mt + ch}" x2="{ml + cw}" y2="{mt + ch}" '
        f'stroke="#94a3b8" stroke-width="1.5"/>'
    )

    # Reference lines (Lipinski boundaries)
    if reference_lines:
        for rl in reference_lines:
            if rl["axis"] == "x":
                px = sx(rl["value"])
                svg.append(
                    f'<line x1="{px:.1f}" y1="{mt}" x2="{px:.1f}" y2="{mt + ch}" '
                    f'stroke="#ef4444" stroke-width="1.5" stroke-dasharray="6,4"/>'
                )
                svg.append(
                    f'<text x="{px + 4:.1f}" y="{mt + 14}" font-size="10" '
                    f'fill="#ef4444" font-weight="600">{rl["label"]}</text>'
                )
            else:
                py = sy(rl["value"])
                svg.append(
                    f'<line x1="{ml}" y1="{py:.1f}" x2="{ml + cw}" y2="{py:.1f}" '
                    f'stroke="#ef4444" stroke-width="1.5" stroke-dasharray="6,4"/>'
                )
                svg.append(
                    f'<text x="{ml + cw - 4:.1f}" y="{py - 6:.1f}" text-anchor="end" '
                    f'font-size="10" fill="#ef4444" font-weight="600">{rl["label"]}</text>'
                )

    # Drug-like zone shading (MW<=500 and LogP<=5 quadrant)
    if reference_lines:
        mw_line = next((rl for rl in reference_lines if rl["axis"] == "x"), None)
        logp_line = next((rl for rl in reference_lines if rl["axis"] == "y"), None)
        if mw_line and logp_line:
            zx1 = sx(x_min)
            zx2 = sx(min(mw_line["value"], x_max))
            zy1 = sy(min(logp_line["value"], y_max))
            zy2 = sy(y_min)
            svg.append(
                f'<rect x="{zx1:.1f}" y="{zy1:.1f}" '
                f'width="{zx2 - zx1:.1f}" height="{zy2 - zy1:.1f}" '
                f'fill="#0891b2" opacity="0.06" rx="4"/>'
            )
            svg.append(
                f'<text x="{zx1 + 8:.1f}" y="{zy2 - 8:.1f}" font-size="11" '
                f'fill="#0891b2" font-weight="600" opacity="0.6">Drug-like Zone</text>'
            )

    # Points
    point_colors = ["#0891b2", "#0e4f6e", "#06d6a0", "#155e75", "#0284c7",
                    "#059669", "#0d9488", "#0369a1", "#047857", "#115e59",
                    "#0c4a6e", "#064e3b", "#1e3a5f", "#134e4a", "#075985"]
    for i, pt in enumerate(points):
        px, py = sx(pt["x"]), sy(pt["y"])
        color = point_colors[i % len(point_colors)]
        svg.append(
            f'<circle cx="{px:.1f}" cy="{py:.1f}" r="5" fill="{color}" '
            f'stroke="#fff" stroke-width="1.5" opacity="0.85"/>'
        )
        # Label offset to avoid overlap
        offset_x = 8
        offset_y = -8 if i % 2 == 0 else 14
        label = pt["label"] if len(pt["label"]) <= 12 else pt["label"][:10] + ".."
        svg.append(
            f'<text x="{px + offset_x:.1f}" y="{py + offset_y:.1f}" '
            f'font-size="9" fill="#1a1a2e">{label}</text>'
        )

    # Title
    if title:
        svg.append(
            f'<text x="{width / 2}" y="22" text-anchor="middle" '
            f'font-size="14" font-weight="600" fill="#0e4f6e">{title}</text>'
        )

    # Axis labels
    if x_label:
        svg.append(
            f'<text x="{ml + cw / 2}" y="{height - 6}" text-anchor="middle" '
            f'font-size="12" fill="#6c757d">{x_label}</text>'
        )
    if y_label:
        svg.append(
            f'<text x="14" y="{mt + ch / 2}" text-anchor="middle" '
            f'font-size="12" fill="#6c757d" '
            f'transform="rotate(-90, 14, {mt + ch / 2})">{y_label}</text>'
        )

    svg.append("</svg>")
    return "\n".join(svg)
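
# Hypothetical helper (not used by the pipeline): a minimal sketch of the
# data-to-pixel mapping that the sx()/sy() closures above are assumed to
# perform. The drug-like zone rect computes height as sy(y_min) - sy(limit),
# which implies an inverted y axis: SVG y grows downward, so y_min lands at
# the bottom of the plot area (mt + ch) and y_max at the top (mt).
def _example_linear_scales(
    x_min: float, x_max: float, y_min: float, y_max: float,
    ml: float, mt: float, cw: float, ch: float,
):
    """Return (sx, sy) functions mapping data coordinates to SVG pixels."""
    x_range = (x_max - x_min) or 1.0  # avoid dividing by zero on flat data
    y_range = (y_max - y_min) or 1.0

    def sx(x: float) -> float:
        return ml + (x - x_min) / x_range * cw

    def sy(y: float) -> float:
        # Invert so larger data values sit higher on the chart.
        return mt + ch - (y - y_min) / y_range * ch

    return sx, sy


# Example: a 600 x 300 px plot area offset 60 px from the left, 40 px from the top.
# sx, sy = _example_linear_scales(0, 600, -2, 8, ml=60, mt=40, cw=600, ch=300)
# sx(300) -> 360.0, sy(-2) -> 340.0 (bottom edge), sy(8) -> 40.0 (top edge)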

def _make_funnel(
    stages: list[dict],
    title: str = "",
    width: int = 600,
    height: int = 400,
) -> str:
    """Generate an SVG funnel visualization.

    Args:
        stages: List of dicts with "label", "count", and "total" keys. Stage
            widths are scaled relative to the first stage's count.
        title: Chart title.
        width: SVG width in pixels.
        height: SVG height in pixels.

    Returns:
        SVG string.
    """
    if not stages:
        return ""

    n = len(stages)
    mt = 50
    mb = 20
    available_h = height - mt - mb
    stage_h = available_h / n
    cx = width / 2

    # Color gradient from light cyan to deep teal
    colors = []
    for i in range(n):
        t = i / max(n - 1, 1)
        r = int(207 - t * (207 - 14))
        g = int(250 - t * (250 - 79))
        b = int(254 - t * (254 - 110))
        colors.append(f"rgb({r},{g},{b})")

    svg = [
        f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {height}" '
        f'style="width:100%;max-width:{width}px;height:auto;">',
        f'<rect width="{width}" height="{height}" fill="#fff" rx="6"/>',
    ]

    if title:
        svg.append(
            f'<text x="{cx}" y="28" text-anchor="middle" '
            f'font-size="16" font-weight="700" fill="#0e4f6e">{title}</text>'
        )

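    # Scale every trapezoid relative to the first (largest) stage. Clamp to 1
    # so a first stage with a count of 0 does not cause a division by zero.
    max_count = max(stages[0]["count"], 1)
    max_width = width * 0.75
    prev_w_bot = max_width * (stages[0]["count"] / max_count)

    for i, stage in enumerate(stages):
        y_top = mt + i * stage_h
        y_bot = y_top + stage_h

        # Each trapezoid's top edge matches the previous stage's bottom edge;
        # its bottom edge is sized by the next stage's count (the last stage
        # narrows to 70% of its own width).
        w_top = prev_w_bot
        if i < n - 1:
            w_bot = max_width * (stages[i + 1]["count"] / max_count)
        else:
            w_bot = max_width * (stage["count"] / max_count) * 0.7

        prev_w_bot = w_bot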

        # Trapezoid
        x1_top = cx - w_top / 2
        x2_top = cx + w_top / 2
        x1_bot = cx - w_bot / 2
        x2_bot = cx + w_bot / 2

        svg.append(
            f'<polygon points="{x1_top:.1f},{y_top:.1f} {x2_top:.1f},{y_top:.1f} '
            f'{x2_bot:.1f},{y_bot:.1f} {x1_bot:.1f},{y_bot:.1f}" '
            f'fill="{colors[i]}" stroke="#fff" stroke-width="2"/>'
        )

        # Text: dark on light, white on dark
        t = i / max(n - 1, 1)
        txt_color = "#0e4f6e" if t < 0.5 else "#fff"
        y_mid = (y_top + y_bot) / 2

        svg.append(
            f'<text x="{cx}" y="{y_mid - 4:.1f}" text-anchor="middle" '
            f'font-size="13" font-weight="600" fill="{txt_color}">{stage["label"]}</text>'
        )
        svg.append(
            f'<text x="{cx}" y="{y_mid + 14:.1f}" text-anchor="middle" '
            f'font-size="12" fill="{txt_color}" opacity="0.85">'
            f'{stage["count"]} / {stage["total"]}</text>'
        )

    svg.append("</svg>")
    return "\n".join(svg)
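
# Hypothetical helper (not used by the pipeline): the funnel's stage colors
# are a linear interpolation between light cyan rgb(207,250,254) and deep teal
# rgb(14,79,110). This sketch factors that arithmetic out; int() truncates
# toward zero exactly as the inline color loop in _make_funnel does.
def _example_lerp_rgb(
    start: tuple[int, int, int], end: tuple[int, int, int], t: float
) -> str:
    """Blend two RGB colors; t=0 returns start, t=1 returns end."""
    r, g, b = (int(s - t * (s - e)) for s, e in zip(start, end))
    return f"rgb({r},{g},{b})"


# Example: five stages use t = 0, 0.25, 0.5, 0.75, 1.
# [_example_lerp_rgb((207, 250, 254), (14, 79, 110), i / 4) for i in range(5)]
# first -> "rgb(207,250,254)", last -> "rgb(14,79,110)"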

# ------------------------------------------------------------------
# Task 1: Load and validate molecules
# ------------------------------------------------------------------

@tool
@env.task(cache="auto")
async def load_molecules(
    molecules_json: str = "",
) -> flyte.io.Dir:
    """Parse SMILES strings, validate with RDKit, generate 2D depictions.

    Args:
        molecules_json: JSON string mapping molecule names to SMILES.
            Defaults to a curated library of ~15 well-known drugs.

    Returns:
        flyte.io.Dir containing molecule data (JSON + PNG depictions).
        Pass this directory to compute_properties and generate_report.
    """
    from rdkit import Chem
    from rdkit.Chem import Draw

    if molecules_json.strip():
        molecules = json.loads(molecules_json)
    else:
        molecules = DEFAULT_MOLECULES

    out_dir = tempfile.mkdtemp(prefix="mol_library_")
    results = []
    valid_count = 0
    invalid_count = 0

    log.info(f"Parsing {len(molecules)} molecules...")

    for name, smiles in molecules.items():
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            log.warning(f"  [INVALID] {name}: {smiles}")
            invalid_count += 1
            continue

        valid_count += 1

        # Generate 2D depiction as PNG
        img = Draw.MolToImage(mol, size=(300, 300))
        img_path = os.path.join(out_dir, f"{name.replace(' ', '_')}.png")
        img.save(img_path)

        results.append({
            "name": name,
            "smiles": smiles,
            "valid": True,
            "image_file": os.path.basename(img_path),
        })

    # Save molecule manifest
    manifest = {
        "total": len(molecules),
        "valid": valid_count,
        "invalid": invalid_count,
        "molecules": results,
    }
    manifest_path = os.path.join(out_dir, "manifest.json")
    with open(manifest_path, "w") as f:
        json.dump(manifest, f, indent=2)

    log.info(f"Loaded {valid_count} valid molecules ({invalid_count} invalid)")

    return await flyte.io.Dir.from_local(out_dir)

# ------------------------------------------------------------------
# Task 2: Compute physicochemical properties
# ------------------------------------------------------------------

@tool
@env.task(report=True)
async def compute_properties(
    molecule_dir: flyte.io.Dir,
) -> str:
    """Compute drug-likeness properties for all molecules.

    Computes MW, LogP, HBD, HBA, TPSA, rotatable bonds, formal charge,
    ring count, QED, and Lipinski Rule of Five compliance.

    Args:
        molecule_dir: Directory from load_molecules.

    Returns:
        JSON string with all computed properties. Pass to screen_candidates
        and generate_report.
    """
    from rdkit import Chem
    from rdkit.Chem import Descriptors, Lipinski
    from rdkit.Chem.QED import qed

    # --- Loading report ---
    await flyte.report.replace.aio(
        _wrap_report("<h2>Computing Molecular Properties...</h2>"
                      "<p>Analyzing physicochemical descriptors for all molecules.</p>"),
        do_flush=True,
    )

    mol_dir = await molecule_dir.download()
    with open(os.path.join(mol_dir, "manifest.json")) as f:
        manifest = json.load(f)

    molecules_data = []
    lipinski_pass = 0

    for mol_info in manifest["molecules"]:
        mol = Chem.MolFromSmiles(mol_info["smiles"])
        if mol is None:
            continue

        mw = Descriptors.MolWt(mol)
        logp = Descriptors.MolLogP(mol)
        hbd = Lipinski.NumHDonors(mol)
        hba = Lipinski.NumHAcceptors(mol)
        tpsa = Descriptors.TPSA(mol)
        rotatable = Lipinski.NumRotatableBonds(mol)
        formal_charge = Chem.GetFormalCharge(mol)
        num_rings = Lipinski.RingCount(mol)
        qed_score = qed(mol)

        # Lipinski Rule of Five
        lipinski = {
            "mw_ok": mw <= 500,
            "logp_ok": logp <= 5,
            "hbd_ok": hbd <= 5,
            "hba_ok": hba <= 10,
        }
        lipinski_all = all(lipinski.values())
        if lipinski_all:
            lipinski_pass += 1

        # Read image for data URI
        img_path = os.path.join(mol_dir, mol_info["image_file"])
        data_uri = ""
        if os.path.exists(img_path):
            with open(img_path, "rb") as img_f:
                b64 = base64.b64encode(img_f.read()).decode()
                data_uri = f"data:image/png;base64,{b64}"

        molecules_data.append({
            "name": mol_info["name"],
            "smiles": mol_info["smiles"],
            "mw": round(mw, 2),
            "logp": round(logp, 2),
            "hbd": hbd,
            "hba": hba,
            "tpsa": round(tpsa, 2),
            "rotatable_bonds": rotatable,
            "formal_charge": formal_charge,
            "num_rings": num_rings,
            "qed": round(qed_score, 4),
            "lipinski": lipinski,
            "lipinski_pass": lipinski_all,
            "image_data_uri": data_uri,
        })

    total = len(molecules_data)
    avg_mw = sum(m["mw"] for m in molecules_data) / total if total else 0
    avg_logp = sum(m["logp"] for m in molecules_data) / total if total else 0
    lipinski_rate = lipinski_pass / total * 100 if total else 0

    # ---- Build report ----
    html_parts = []

    # Header
    html_parts.append("<h2>Molecular Properties Analysis</h2>")

    # Stat grid
    html_parts.append('<div class="stat-grid">')
    for val, label in [
        (str(total), "Total Molecules"),
        (f"{lipinski_rate:.0f}%", "Lipinski Pass Rate"),
        (f"{avg_mw:.1f}", "Avg. MW (Da)"),
        (f"{avg_logp:.2f}", "Avg. LogP"),
    ]:
        html_parts.append(
            f'<div class="stat"><div class="value">{val}</div>'
            f'<div class="label">{label}</div></div>'
        )
    html_parts.append("</div>")

    # Molecule gallery
    html_parts.append("<h3>Molecule Library</h3>")
    html_parts.append('<div class="molecule-grid">')
    for m in molecules_data:
        if m["image_data_uri"]:
            badge_class = "badge-success" if m["lipinski_pass"] else "badge-danger"
            badge_text = "Lipinski Pass" if m["lipinski_pass"] else "Lipinski Fail"
            html_parts.append(
                f'<div class="molecule-card" style="text-align:center;">'
                f'<img src="{m["image_data_uri"]}" style="width:160px;height:160px;object-fit:contain;"/>'
                f'<div style="font-weight:600;margin-top:6px;color:#0e4f6e;">{m["name"]}</div>'
                f'<div style="font-size:0.8em;color:#6c757d;">MW: {m["mw"]:.1f} | LogP: {m["logp"]:.2f}</div>'
                f'<div><span class="badge {badge_class}">{badge_text}</span></div>'
                f'</div>'
            )
    html_parts.append("</div>")

    # MW bar chart (horizontal, sorted)
    sorted_by_mw = sorted(molecules_data, key=lambda m: m["mw"], reverse=True)
    mw_labels = [m["name"] for m in sorted_by_mw]
    mw_vals = [m["mw"] for m in sorted_by_mw]
    mw_chart = _make_bar_chart(
        mw_labels, {"MW (Da)": mw_vals},
        title="Molecular Weight Distribution",
        horizontal=True,
        width=700, height=max(300, len(mw_labels) * 30 + 80),
        value_fmt=".1f",
    )
    html_parts.append("<h3>Molecular Weight</h3>")
    html_parts.append(f'<div class="chart-container">{mw_chart}</div>')

    # LogP vs MW scatter plot
    scatter_points = [
        {"x": m["mw"], "y": m["logp"], "label": m["name"]}
        for m in molecules_data
    ]
    scatter_chart = _make_scatter_plot(
        scatter_points,
        x_label="Molecular Weight (Da)",
        y_label="LogP",
        title="LogP vs. Molecular Weight (Lipinski Boundaries)",
        reference_lines=[
            {"axis": "x", "value": 500, "label": "MW = 500"},
            {"axis": "y", "value": 5, "label": "LogP = 5"},
        ],
        width=700,
        height=420,
    )
    html_parts.append("<h3>Lipinski Space</h3>")
    html_parts.append(f'<div class="chart-container">{scatter_chart}</div>')

    # Property heatmap (molecules x properties)
    prop_names = ["MW", "LogP", "HBD", "HBA", "TPSA", "Rot. Bonds"]
    # Normalize each property to 0-1 for heatmap
    raw_matrix = []
    for m in molecules_data:
        raw_matrix.append([m["mw"], m["logp"], m["hbd"], m["hba"], m["tpsa"], m["rotatable_bonds"]])

    # Normalize per column
    n_props = len(prop_names)
    col_min = [min(row[c] for row in raw_matrix) for c in range(n_props)]
    col_max = [max(row[c] for row in raw_matrix) for c in range(n_props)]
    norm_matrix = []
    for row in raw_matrix:
        norm_row = []
        for c in range(n_props):
            rng = col_max[c] - col_min[c]
            norm_row.append((row[c] - col_min[c]) / rng if rng else 0.5)
        norm_matrix.append(norm_row)

    heatmap_labels = [m["name"] for m in molecules_data]
    heatmap = _make_heatmap(
        norm_matrix, heatmap_labels, prop_names,
        title="Normalized Property Heatmap",
        color_scale="cyan",
        width=700,
        height=max(400, len(heatmap_labels) * 28 + 100),
    )
    html_parts.append("<h3>Property Heatmap</h3>")
    html_parts.append(f'<div class="chart-container">{heatmap}</div>')

    # Lipinski compliance table
    html_parts.append("<h3>Lipinski Rule of Five Compliance</h3>")
    html_parts.append("<table><tr><th>Molecule</th><th>MW &le; 500</th>"
                      "<th>LogP &le; 5</th><th>HBD &le; 5</th>"
                      "<th>HBA &le; 10</th><th>Overall</th></tr>")
    # Pass/fail badge helper, defined once rather than on every loop iteration.
    def _badge(ok):
        if ok:
            return '<span class="badge badge-success">Pass</span>'
        return '<span class="badge badge-danger">Fail</span>'

    for m in molecules_data:
        lip = m["lipinski"]
        overall_badge = _badge(m["lipinski_pass"])
        html_parts.append(
            f'<tr><td><strong>{m["name"]}</strong></td>'
            f'<td>{_badge(lip["mw_ok"])}</td>'
            f'<td>{_badge(lip["logp_ok"])}</td>'
            f'<td>{_badge(lip["hbd_ok"])}</td>'
            f'<td>{_badge(lip["hba_ok"])}</td>'
            f'<td>{overall_badge}</td></tr>'
        )
    html_parts.append("</table>")

    # QED bar chart
    sorted_by_qed = sorted(molecules_data, key=lambda m: m["qed"], reverse=True)
    qed_labels = [m["name"] for m in sorted_by_qed]
    qed_vals = [m["qed"] for m in sorted_by_qed]
    qed_chart = _make_bar_chart(
        qed_labels, {"QED Score": qed_vals},
        title="Drug-likeness (QED Score)",
        horizontal=True,
        width=700, height=max(300, len(qed_labels) * 30 + 80),
        value_fmt=".3f",
        colors=["#06d6a0"],
    )
    html_parts.append("<h3>Drug-likeness (QED)</h3>")
    html_parts.append(f'<div class="chart-container">{qed_chart}</div>')

    # Flush full report
    await flyte.report.replace.aio(
        _wrap_report("\n".join(html_parts)),
        do_flush=True,
    )

    # Return properties as JSON (strip image data URIs to reduce size)
    output = {
        "total": total,
        "lipinski_pass_count": lipinski_pass,
        "lipinski_pass_rate": round(lipinski_rate, 2),
        "avg_mw": round(avg_mw, 2),
        "avg_logp": round(avg_logp, 2),
        "molecules": [
            {k: v for k, v in m.items() if k != "image_data_uri"}
            for m in molecules_data
        ],
    }
    return json.dumps(output)
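
# Hypothetical helper (not used by the pipeline): the property heatmap in
# compute_properties rescales each column independently to [0, 1] with
# min-max normalization, (v - min) / (max - min), and maps constant columns
# to 0.5. This standalone sketch mirrors that logic for a row-major matrix.
def _example_minmax_columns(matrix: list[list[float]]) -> list[list[float]]:
    """Normalize each column of a row-major matrix to the range [0, 1]."""
    if not matrix:
        return []
    n_cols = len(matrix[0])
    col_min = [min(row[c] for row in matrix) for c in range(n_cols)]
    col_max = [max(row[c] for row in matrix) for c in range(n_cols)]
    normalized = []
    for row in matrix:
        out = []
        for c in range(n_cols):
            span = col_max[c] - col_min[c]
            # A constant column carries no contrast, so place it mid-scale.
            out.append((row[c] - col_min[c]) / span if span else 0.5)
        normalized.append(out)
    return normalized


# Example: MW in column 0, HBD in column 1 (HBD is constant across rows).
# _example_minmax_columns([[180.0, 1], [300.0, 1], [420.0, 1]])
# -> [[0.0, 0.5], [0.5, 0.5], [1.0, 0.5]]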

# ------------------------------------------------------------------
# Task 3: Screen candidates against target profile
# ------------------------------------------------------------------

@tool
@env.task(report=True)
async def screen_candidates(
    properties_json: str,
    target_profile: str = "",
) -> str:
    """Screen molecules against a target drug profile and rank candidates.

    Scores each molecule on how well it matches the target profile, computes
    pairwise Tanimoto similarity, and produces a ranked list.

    Args:
        properties_json: JSON from compute_properties.
        target_profile: JSON string with desired property ranges
            (e.g. {"mw": [150, 500], "logp": [-0.5, 5.0]}).

    Returns:
        JSON string with ranked_molecules, similarity_matrix, similarity_labels,
        funnel, and target_profile. Pass the full return value verbatim to
        generate_report along with molecule_dir and properties_json.
    """
    from rdkit import Chem, DataStructs
    from rdkit.Chem import AllChem

    await flyte.report.replace.aio(
        _wrap_report("<h2>Screening Candidates...</h2>"
                      "<p>Evaluating molecules against the target drug profile.</p>"),
        do_flush=True,
    )

    props = json.loads(properties_json)
    molecules = props["molecules"]

    # Default target profile
    if target_profile.strip():
        profile = json.loads(target_profile)
    else:
        profile = {
            "mw": [150, 500],
            "logp": [-0.5, 5.0],
            "hbd": [0, 5],
            "hba": [0, 10],
            "tpsa": [20, 140],
        }

    # --- Screening ---
    funnel_total = len(molecules)
    pass_mw = 0
    pass_logp = 0
    pass_lipinski = 0
    final_candidates = 0

    scored = []
    for m in molecules:
        score = 0
        max_score = 0
        criteria = {}

        # Check each profile criterion
        checks = [
            ("mw", m["mw"]),
            ("logp", m["logp"]),
            ("hbd", m["hbd"]),
            ("hba", m["hba"]),
            ("tpsa", m["tpsa"]),
        ]

        for key, val in checks:
            if key in profile:
                lo, hi = profile[key]
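                # In range earns 1 point plus up to a 0.5 midpoint bonus, so
                # budget 1.5 per criterion to keep the normalized score <= 1.
                max_score += 1.5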
                in_range = lo <= val <= hi
                criteria[key] = in_range
                if in_range:
                    score += 1
                    # Bonus: closer to midpoint = higher score
                    mid = (lo + hi) / 2
                    rng = (hi - lo) / 2
                    dist = abs(val - mid) / rng if rng else 0
                    score += max(0, 0.5 * (1 - dist))

        # QED bonus
        score += m["qed"] * 2
        max_score += 2

        # Lipinski bonus
        if m["lipinski_pass"]:
            score += 1
        max_score += 1

        normalized_score = score / max_score if max_score else 0

        # Funnel tracking — cascading filter (each stage requires passing the previous)
        mw_ok = criteria.get("mw", True)
        logp_ok = criteria.get("logp", True)
        if mw_ok:
            pass_mw += 1
            if logp_ok:
                pass_logp += 1
                if m["lipinski_pass"]:
                    pass_lipinski += 1
                    if all(criteria.values()):
                        final_candidates += 1

        scored.append({
            **m,
            "screening_score": round(normalized_score, 4),
            "criteria_met": criteria,
            "all_criteria_met": all(criteria.values()),
        })

    # Sort by score descending
    scored.sort(key=lambda m: m["screening_score"], reverse=True)

    # --- Tanimoto similarity matrix ---
    fps = []
    valid_names = []
    for m in scored:
        mol = Chem.MolFromSmiles(m["smiles"])
        if mol:
            fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048)
            fps.append(fp)
            valid_names.append(m["name"])

    similarity_matrix = []
    for i in range(len(fps)):
        row = []
        for j in range(len(fps)):
            sim = DataStructs.TanimotoSimilarity(fps[i], fps[j])
            row.append(round(sim, 3))
        similarity_matrix.append(row)

    # ---- Build report ----
    html_parts = []
    html_parts.append("<h2>Candidate Screening Results</h2>")

    # Stat grid
    html_parts.append('<div class="stat-grid">')
    for val, label in [
        (str(funnel_total), "Total Screened"),
        (str(pass_lipinski), "Lipinski Passes"),
        (str(final_candidates), "All Criteria Met"),
        (f"{scored[0]['screening_score']:.3f}" if scored else "N/A", "Top Score"),
    ]:
        html_parts.append(
            f'<div class="stat"><div class="value">{val}</div>'
            f'<div class="label">{label}</div></div>'
        )
    html_parts.append("</div>")

    # Screening funnel
    funnel_stages = [
        {"label": "Total Molecules", "count": funnel_total, "total": funnel_total},
        {"label": "Pass MW Filter", "count": pass_mw, "total": funnel_total},
        {"label": "Pass LogP Filter", "count": pass_logp, "total": funnel_total},
        {"label": "Lipinski Compliant", "count": pass_lipinski, "total": funnel_total},
        {"label": "All Criteria Met", "count": final_candidates, "total": funnel_total},
    ]
    funnel_svg = _make_funnel(
        funnel_stages,
        title="Screening Funnel",
        width=600,
        height=380,
    )
    html_parts.append("<h3>Screening Funnel</h3>")
    html_parts.append(f'<div class="chart-container" style="text-align:center;">{funnel_svg}</div>')

    # Ranked candidates table
    html_parts.append("<h3>Ranked Candidates</h3>")
    html_parts.append(
        "<table><tr><th>Rank</th><th>Molecule</th><th>Score</th>"
        "<th>MW</th><th>LogP</th><th>QED</th><th>Lipinski</th><th>All Criteria</th></tr>"
    )
    for rank, m in enumerate(scored, 1):
        lip_badge = ('<span class="badge badge-success">Pass</span>'
                     if m["lipinski_pass"]
                     else '<span class="badge badge-danger">Fail</span>')
        crit_badge = ('<span class="badge badge-success">Pass</span>'
                      if m["all_criteria_met"]
                      else '<span class="badge badge-danger">Fail</span>')
        # Highlight top 3
        row_style = ' style="background:#ecfeff;font-weight:600;"' if rank <= 3 else ""
        html_parts.append(
            f"<tr{row_style}><td>{rank}</td><td>{m['name']}</td>"
            f"<td>{m['screening_score']:.3f}</td>"
            f"<td>{m['mw']:.1f}</td><td>{m['logp']:.2f}</td>"
            f"<td>{m['qed']:.3f}</td><td>{lip_badge}</td><td>{crit_badge}</td></tr>"
        )
    html_parts.append("</table>")

    # Top 5 candidate cards with structures
    html_parts.append("<h3>Top 5 Candidates</h3>")
    html_parts.append('<div class="molecule-grid">')
    for m in scored[:5]:
        mol = Chem.MolFromSmiles(m["smiles"])
        img_uri = _mol_to_data_uri(mol, size=(250, 250)) if mol else ""
        badge_class = "badge-success" if m["all_criteria_met"] else "badge-info"
        badge_text = "All Criteria Met" if m["all_criteria_met"] else "Partial Match"
        html_parts.append(
            f'<div class="molecule-card" style="text-align:center;">'
            f'<img src="{img_uri}" style="width:140px;height:140px;object-fit:contain;"/>'
            f'<div style="font-weight:700;margin-top:6px;color:#0e4f6e;font-size:1.05em;">{m["name"]}</div>'
            f'<div style="font-size:0.85em;color:#155e75;margin:4px 0;">Score: {m["screening_score"]:.3f}</div>'
            f'<div style="font-size:0.8em;color:#6c757d;">MW: {m["mw"]:.1f} | LogP: {m["logp"]:.2f} | QED: {m["qed"]:.3f}</div>'
            f'<div style="margin-top:4px;"><span class="badge {badge_class}">{badge_text}</span></div>'
            f'</div>'
        )
    html_parts.append("</div>")

    # Tanimoto similarity heatmap
    if similarity_matrix:
        sim_heatmap = _make_heatmap(
            similarity_matrix, valid_names, valid_names,
            title="Pairwise Tanimoto Similarity (Morgan Fingerprints)",
            color_scale="cyan",
            width=700,
            height=max(500, len(valid_names) * 32 + 100),
        )
        html_parts.append("<h3>Chemical Similarity</h3>")
        html_parts.append(f'<div class="chart-container">{sim_heatmap}</div>')

    await flyte.report.replace.aio(
        _wrap_report("\n".join(html_parts)),
        do_flush=True,
    )

    output = {
        "ranked_molecules": scored,
        "similarity_matrix": similarity_matrix,
        "similarity_labels": valid_names,
        "funnel": funnel_stages,
        "target_profile": profile,
    }
    return json.dumps(output)
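
# Hypothetical helper (not used by the pipeline): the per-criterion part of
# the screening score in screen_candidates. A value inside [lo, hi] earns 1
# point plus a bonus of up to 0.5 that shrinks linearly with its distance from
# the midpoint of the range; a value outside the range earns 0.
def _example_range_score(val: float, lo: float, hi: float) -> float:
    """Score one property value against a target range."""
    if not lo <= val <= hi:
        return 0.0
    mid = (lo + hi) / 2
    half_width = (hi - lo) / 2
    dist = abs(val - mid) / half_width if half_width else 0.0
    return 1.0 + max(0.0, 0.5 * (1.0 - dist))


# Worked example with the default MW window [150, 500] (midpoint 325):
# _example_range_score(325, 150, 500)   -> 1.5   (exactly at the midpoint)
# _example_range_score(237.5, 150, 500) -> 1.25  (halfway to the boundary)
# _example_range_score(500, 150, 500)   -> 1.0   (on the boundary, no bonus)
# _example_range_score(600, 150, 500)   -> 0.0   (out of range)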

def _parse_screening_json(screening_json: str) -> dict:
    """Parse screening JSON from screen_candidates, with safe defaults.

    The agent must pass the exact value returned by the tool. Only the
    similarity fields are optional: if they are missing, they default to
    empty lists.
    """
    screening = json.loads(screening_json)
    if "ranked_molecules" not in screening:
        raise ValueError(
            "screening_json must be the exact JSON string returned by "
            "screen_candidates (missing 'ranked_molecules'). Do not construct, "
            "truncate, or summarize tool output."
        )
    screening.setdefault("similarity_matrix", [])
    screening.setdefault("similarity_labels", [])
    return screening
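
# Hypothetical helper (not used by the pipeline): the similarity_matrix that
# screen_candidates returns holds pairwise Tanimoto coefficients of Morgan bit
# vectors. Written out over the sets of "on" bit indices, Tanimoto is
# |A & B| / |A | B|: identical fingerprints score 1.0 (the matrix diagonal)
# and fingerprints with no shared bits score 0.0.
def _example_tanimoto(on_bits_a: set[int], on_bits_b: set[int]) -> float:
    """Tanimoto (Jaccard) coefficient of two sets of set-bit indices."""
    union = on_bits_a | on_bits_b
    if not union:
        return 0.0  # assumption of this sketch: two empty fingerprints score 0
    return len(on_bits_a & on_bits_b) / len(union)


# Example: 2 shared bits out of 4 distinct bits.
# _example_tanimoto({1, 5, 9}, {5, 9, 12}) -> 0.5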

# ------------------------------------------------------------------
# Task 4: Generate final comprehensive report
# ------------------------------------------------------------------
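
# Hypothetical helper (not used by the pipeline): the property box plots in
# generate_report take Q1 and Q3 by direct indexing (vals[n // 4] and
# vals[3 * n // 4]). This sketch shows an interpolated alternative using the
# standard library; method="inclusive" treats the values as the full
# population. Call it with at least two values.
def _example_five_number_summary(values: list[float]) -> dict[str, float]:
    """Return min, Q1, median, Q3, and max of at least two values."""
    import statistics

    vals = sorted(values)
    q1, median, q3 = statistics.quantiles(vals, n=4, method="inclusive")
    return {"min": vals[0], "q1": q1, "median": median, "q3": q3, "max": vals[-1]}


# Example: _example_five_number_summary([1, 2, 3, 4, 5])
# -> {"min": 1, "q1": 2.0, "median": 3.0, "q3": 4.0, "max": 5}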

@tool
@env.task(report=True)
async def generate_report(
    molecule_dir: flyte.io.Dir,
    properties_json: str,
    screening_json: str,
) -> str:
    """Generate a comprehensive drug screening report.

    Produces an executive summary, top candidate spotlight cards, property
    distributions, chemical diversity analysis, and final recommendation.

    Args:
        molecule_dir: Directory from load_molecules.
        properties_json: JSON from compute_properties.
        screening_json: Exact verbatim JSON string returned by screen_candidates
            (must include ranked_molecules, similarity_matrix, similarity_labels).
            Do not construct or summarize this payload yourself.

    Returns:
        JSON summary with total_screened, lipinski_passes, all_criteria_met,
        top_candidate, top_score, and top_3 ranked molecules.
    """
    from rdkit import Chem

    await flyte.report.replace.aio(
        _wrap_report("<h2>Generating Final Report...</h2>"),
        do_flush=True,
    )

    props = json.loads(properties_json)
    screening = _parse_screening_json(screening_json)
    ranked = screening["ranked_molecules"]
    sim_matrix = screening["similarity_matrix"]
    sim_labels = screening["similarity_labels"]

    total = props["total"]
    lipinski_pass = props["lipinski_pass_count"]
    all_criteria = sum(1 for m in ranked if m["all_criteria_met"])
    top = ranked[0] if ranked else None

    html_parts = []

    # --- Executive Summary ---
    html_parts.append("<h2>Drug Molecule Screening Report</h2>")
    top_name = top["name"] if top else "N/A"
    top_score = f'{top["screening_score"]:.3f}' if top else "N/A"
    html_parts.append(
        f'<div class="card">'
        f'<h3 style="margin-top:0;color:#0e4f6e;">Executive Summary</h3>'
        f'<p style="font-size:1.05em;">'
        f'<strong>{total}</strong> molecules were screened against the target drug profile. '
        f'<strong>{lipinski_pass}</strong> passed Lipinski\'s Rule of Five, and '
        f'<strong>{all_criteria}</strong> met all screening criteria. '
        f'The top candidate is <strong style="color:#0891b2;">{top_name}</strong> '
        f'with a screening score of <strong>{top_score}</strong>.</p>'
        f'</div>'
    )

    # Stat grid
    html_parts.append('<div class="stat-grid">')
    for val, label in [
        (str(total), "Molecules Screened"),
        (str(lipinski_pass), "Lipinski Passes"),
        (str(all_criteria), "All Criteria Met"),
        (top_score, "Top Score"),
        (f'{props["avg_mw"]:.0f} Da', "Avg. Molecular Weight"),
        (f'{props["avg_logp"]:.2f}', "Avg. LogP"),
    ]:
        html_parts.append(
            f'<div class="stat"><div class="value">{val}</div>'
            f'<div class="label">{label}</div></div>'
        )
    html_parts.append("</div>")

    # --- Top 3 Candidate Spotlights ---
    html_parts.append("<h2>Top Candidate Spotlights</h2>")

    for rank, m in enumerate(ranked[:3], 1):
        mol = Chem.MolFromSmiles(m["smiles"])
        img_uri = _mol_to_data_uri(mol, size=(300, 300)) if mol else ""

        medal = ["gold", "silver", "#cd7f32"][rank - 1]
        medal_emoji = ["1st", "2nd", "3rd"][rank - 1]

        lip_badges = ""
        for rule, key in [("MW", "mw_ok"), ("LogP", "logp_ok"),
                          ("HBD", "hbd_ok"), ("HBA", "hba_ok")]:
            ok = m["lipinski"].get(key, False)
            cls = "badge-success" if ok else "badge-danger"
            lip_badges += f'<span class="badge {cls}" style="margin:2px;">{rule}</span> '

        html_parts.append(
            f'<div class="molecule-card" style="display:flex;gap:20px;align-items:flex-start;flex-wrap:wrap;">'
            f'<div style="text-align:center;min-width:180px;">'
            f'<div style="font-size:1.6em;font-weight:800;color:{medal};">{medal_emoji}</div>'
            f'<img src="{img_uri}" style="width:200px;height:200px;object-fit:contain;border-radius:8px;'
            f'border:2px solid #a5f3fc;"/>'
            f'<div style="font-weight:700;font-size:1.1em;color:#0e4f6e;margin-top:8px;">{m["name"]}</div>'
            f'</div>'
            f'<div style="flex:1;min-width:280px;">'
            f'<table style="margin:0;">'
            f'<tr><td><strong>SMILES</strong></td><td style="font-family:monospace;font-size:0.8em;word-break:break-all;">{m["smiles"]}</td></tr>'
            f'<tr><td><strong>Screening Score</strong></td><td style="font-weight:700;color:#0891b2;font-size:1.1em;">{m["screening_score"]:.3f}</td></tr>'
            f'<tr><td><strong>Molecular Weight</strong></td><td>{m["mw"]:.1f} Da</td></tr>'
            f'<tr><td><strong>LogP</strong></td><td>{m["logp"]:.2f}</td></tr>'
            f'<tr><td><strong>H-Bond Donors</strong></td><td>{m["hbd"]}</td></tr>'
            f'<tr><td><strong>H-Bond Acceptors</strong></td><td>{m["hba"]}</td></tr>'
            f'<tr><td><strong>TPSA</strong></td><td>{m["tpsa"]:.1f} &Aring;&sup2;</td></tr>'
            f'<tr><td><strong>Rotatable Bonds</strong></td><td>{m["rotatable_bonds"]}</td></tr>'
            f'<tr><td><strong>QED</strong></td><td>{m["qed"]:.4f}</td></tr>'
            f'<tr><td><strong>Lipinski Compliance</strong></td><td>{lip_badges}</td></tr>'
            f'</table>'
            f'</div>'
            f'</div>'
        )

    # --- Property distributions: one horizontal box plot per property (min, Q1, median, Q3, max) ---
    html_parts.append("<h2>Property Distributions</h2>")

    prop_keys = [("mw", "Molecular Weight (Da)"), ("logp", "LogP"),
                 ("tpsa", "TPSA"), ("qed", "QED Score")]
    for key, label in prop_keys:
        vals = sorted([m[key] for m in ranked])
        n = len(vals)
        if n == 0:
            continue
        v_min = vals[0]
        v_max = vals[-1]
        median = vals[n // 2] if n % 2 == 1 else (vals[n // 2 - 1] + vals[n // 2]) / 2
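        # Q1/Q3 by direct indexing into the sorted values: a coarse
        # approximation that falls back to min/max when n < 4.
        q1 = vals[n // 4] if n >= 4 else v_min
        q3 = vals[3 * n // 4] if n >= 4 else v_max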

        # Simple horizontal box-plot as SVG
        box_w = 500
        box_h = 50
        margin_l = 10
        v_range = v_max - v_min or 1

        def sx(v):
            return margin_l + ((v - v_min) / v_range) * (box_w - 2 * margin_l)

        box_svg = (
            f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {box_w} {box_h}" '
            f'style="width:100%;max-width:{box_w}px;height:auto;">'
            f'<rect width="{box_w}" height="{box_h}" fill="#fff" rx="4"/>'
            # Whisker line
            f'<line x1="{sx(v_min):.1f}" y1="25" x2="{sx(v_max):.1f}" y2="25" '
            f'stroke="#94a3b8" stroke-width="1.5"/>'
            # Min whisker
            f'<line x1="{sx(v_min):.1f}" y1="18" x2="{sx(v_min):.1f}" y2="32" '
            f'stroke="#94a3b8" stroke-width="1.5"/>'
            # Max whisker
            f'<line x1="{sx(v_max):.1f}" y1="18" x2="{sx(v_max):.1f}" y2="32" '
            f'stroke="#94a3b8" stroke-width="1.5"/>'
            # IQR box
            f'<rect x="{sx(q1):.1f}" y="14" width="{sx(q3) - sx(q1):.1f}" height="22" '
            f'fill="#a5f3fc" stroke="#0891b2" stroke-width="1.5" rx="3"/>'
            # Median line
            f'<line x1="{sx(median):.1f}" y1="12" x2="{sx(median):.1f}" y2="38" '
            f'stroke="#0e4f6e" stroke-width="2"/>'
            # Labels
            f'<text x="{sx(v_min):.1f}" y="46" text-anchor="middle" font-size="9" fill="#6c757d">{v_min:.1f}</text>'
            f'<text x="{sx(median):.1f}" y="10" text-anchor="middle" font-size="9" fill="#0e4f6e" font-weight="600">{median:.1f}</text>'
            f'<text x="{sx(v_max):.1f}" y="46" text-anchor="middle" font-size="9" fill="#6c757d">{v_max:.1f}</text>'
            f'</svg>'
        )
        html_parts.append(
            f'<div style="margin:8px 0;"><strong style="color:#155e75;">{label}</strong>'
            f'<div class="chart-container" style="padding:8px;">{box_svg}</div></div>'
        )

    # --- Chemical Diversity ---
    html_parts.append("<h2>Chemical Diversity Analysis</h2>")

    if sim_matrix and len(sim_matrix) > 1:
        # Compute average pairwise similarity (off-diagonal)
        n_mols = len(sim_matrix)
        off_diag = []
        for i in range(n_mols):
            for j in range(i + 1, n_mols):
                off_diag.append(sim_matrix[i][j])

        avg_sim = sum(off_diag) / len(off_diag) if off_diag else 0
        max_sim = max(off_diag) if off_diag else 0
        min_sim = min(off_diag) if off_diag else 0

        # Find most similar pair
        best_i, best_j = 0, 1
        best_val = 0
        for i in range(n_mols):
            for j in range(i + 1, n_mols):
                if sim_matrix[i][j] > best_val:
                    best_val = sim_matrix[i][j]
                    best_i, best_j = i, j

        html_parts.append('<div class="stat-grid">')
        html_parts.append(
            f'<div class="stat"><div class="value">{avg_sim:.3f}</div>'
            f'<div class="label">Avg. Pairwise Similarity</div></div>'
        )
        html_parts.append(
            f'<div class="stat"><div class="value">{min_sim:.3f}</div>'
            f'<div class="label">Min Similarity</div></div>'
        )
        html_parts.append(
            f'<div class="stat"><div class="value">{max_sim:.3f}</div>'
            f'<div class="label">Max Similarity</div></div>'
        )
        html_parts.append("</div>")

        diversity_text = "highly diverse" if avg_sim < 0.3 else "moderately diverse" if avg_sim < 0.5 else "relatively similar"
        html_parts.append(
            f'<div class="note">'
            f'The library is <strong>{diversity_text}</strong> (avg. Tanimoto = {avg_sim:.3f}). '
            f'The most similar pair is <strong>{sim_labels[best_i]}</strong> and '
            f'<strong>{sim_labels[best_j]}</strong> (similarity = {best_val:.3f}).</div>'
        )

    # --- Recommendation ---
    html_parts.append("<h2>Recommendation</h2>")
    if top:
        html_parts.append(
            f'<div class="card">'
            f'<h3 style="margin-top:0;color:#0891b2;">Top Candidate: {top["name"]}</h3>'
            f'<p>Based on the virtual screening analysis, <strong>{top["name"]}</strong> '
            f'achieved the highest composite screening score of <strong>{top["screening_score"]:.3f}</strong>. '
        )

        reasons = []
        if top["lipinski_pass"]:
            reasons.append("full Lipinski Rule of Five compliance")
        if top["qed"] > 0.5:
            reasons.append(f"high drug-likeness (QED = {top['qed']:.3f})")
        if top.get("all_criteria_met"):
            reasons.append("all target profile criteria met")
        if top["mw"] <= 500:
            reasons.append(f"favorable molecular weight ({top['mw']:.1f} Da)")

        if reasons:
            html_parts.append(
                f'This candidate stands out due to: {", ".join(reasons)}.</p>'
            )
        else:
            html_parts.append("</p>")

        # Runner-up mentions
        if len(ranked) >= 2:
            html_parts.append(
                f'<p style="font-size:0.9em;color:#6c757d;">Runner-up candidates: '
            )
            runners = []
            for m in ranked[1:4]:
                runners.append(f'{m["name"]} (score: {m["screening_score"]:.3f})')
            html_parts.append(", ".join(runners) + ".</p>")

        html_parts.append("</div>")

    # Final note
    html_parts.append(
        '<div class="note">'
        "This is a virtual screening analysis. All candidates should undergo "
        "further computational validation (molecular dynamics, docking) and "
        "experimental testing before advancing to clinical trials.</div>"
    )

    await flyte.report.replace.aio(
        _wrap_report("\n".join(html_parts)),
        do_flush=True,
    )

    # JSON summary
    summary = {
        "total_screened": total,
        "lipinski_passes": lipinski_pass,
        "all_criteria_met": all_criteria,
        "top_candidate": top["name"] if top else None,
        "top_score": top["screening_score"] if top else None,
        "top_3": [
            {"name": m["name"], "score": m["screening_score"]}
            for m in ranked[:3]
        ],
    }
    return json.dumps(summary)

# ------------------------------------------------------------------
# Agent
# ------------------------------------------------------------------

# {{docs-fragment agent}}
SCREENING_AGENT_INSTRUCTIONS = """\
You are a medicinal chemistry screening strategist. You orchestrate a virtual \
screening pipeline using durable Flyte tools. You NEVER invent molecular \
properties — only RDKit tools compute them.

Workflow:
1. If target_profile is not provided in the user message, derive a JSON \
target_profile from the therapeutic brief. Valid keys: mw, logp, hbd, hba, tpsa \
(each [min, max]). Ground choices in oral bioavailability / kinase / CNS rules \
as appropriate to the brief.
2. First pass (always): load_molecules → compute_properties → \
screen_candidates → generate_report. Pass tool outputs between steps exactly \
(molecule_dir from load_molecules into compute_properties and generate_report; \
properties_json from compute_properties into screen_candidates and \
generate_report; screening_json must be the complete, unmodified string \
returned by screen_candidates — never rebuild or summarize JSON yourself).
3. Read the JSON summary returned by generate_report. Reflect:
   - If all_criteria_met == 0: relax exactly ONE profile bound by ~10–20% \
and re-run screen_candidates then generate_report only, reusing the same \
molecule_dir and properties_json from the first pass.
   - If all molecules pass but diversity is a stated goal: note high similarity \
in your summary; do not re-run unless brief asks for stricter filters.
   - Maximum ONE rescreen iteration.
4. Finish with plain text: top candidate, rationale tied to computed metrics \
from the tool JSON, funnel interpretation, and suggested next steps (docking, \
ADMET lab tests).

If the user supplies an explicit target_profile JSON, use it as-is.

Do NOT ask the user for SMILES or molecule lists when molecules_json is empty — \
the default library is loaded automatically.
"""

screening_agent = Agent(
    name="drug-screening-agent",
    instructions=SCREENING_AGENT_INSTRUCTIONS,
    model=MODEL,
    tools=[
        load_molecules,
        compute_properties,
        screen_candidates,
        generate_report,
    ],
    max_turns=12,
)
# {{/docs-fragment agent}}

# ------------------------------------------------------------------
# Pipeline
# ------------------------------------------------------------------

# {{docs-fragment pipeline}}
@env.task(report=True)
async def pipeline(
    brief: str = "Screen the default drug library for orally bioavailable small molecules.",
    molecules_json: str = "",
    target_profile: str = "",
) -> str:
    """Agentic virtual drug molecule screening pipeline.

    A medicinal-chemistry agent interprets the screening brief, derives or
    applies a target profile, orchestrates the RDKit screening stages, and
    optionally re-screens when funnel results are too narrow.

    Args:
        brief: Natural-language therapeutic goal (e.g. oral kinase inhibitors,
            CNS-penetrant small molecules).
        molecules_json: JSON mapping molecule names to SMILES strings.
            Defaults to a curated library of ~15 well-known drugs.
        target_profile: Optional JSON with desired property ranges that
            overrides agent-derived criteria
            (e.g. {"mw": [150, 500], "logp": [-0.5, 5]}).

    Returns:
        Agent summary with screening rationale and key results.
    """
    prompt_parts = [
        f"Screening brief: {brief}",
        'Use molecules_json="" for the built-in default library unless provided below.',
        "Compose the four stage tools in order: load_molecules → compute_properties "
        "→ screen_candidates → generate_report. Pass each tool's full return value "
        "verbatim to the next step (especially screening_json). Re-run "
        "screen_candidates and generate_report at most once if the funnel is too narrow.",
    ]
    if molecules_json.strip():
        prompt_parts.append(f"molecules_json: {molecules_json}")
    if target_profile.strip():
        prompt_parts.append(f"Use this target_profile exactly: {target_profile}")

    result = await screening_agent.run.aio("\n".join(prompt_parts))
    return result.summary or result.error or ""

# {{/docs-fragment pipeline}}

# ------------------------------------------------------------------
# Rescreen demo — tight profile + explicit rescreen instructions
# ------------------------------------------------------------------

# Initial profile is deliberately strict (narrow MW + low LogP cap) so
# all_criteria_met is typically 0 on the default library; the brief then
# forces a single rescreen with a widened LogP window.
RESCREEN_DEMO_TARGET_PROFILE = (
    '{"mw": [150, 200], "logp": [-0.5, 1.0], "hbd": [0, 1], '
    '"hba": [0, 3], "tpsa": [20, 45]}'
)
RESCREEN_DEMO_TARGET_PROFILE_RESCREEN = (
    '{"mw": [150, 200], "logp": [-0.5, 3.5], "hbd": [0, 1], '
    '"hba": [0, 3], "tpsa": [20, 45]}'
)
RESCREEN_DEMO_BRIEF = f"""\
Two-round agentic screening demo on the default library.

**Round 1 (strict profile):** load_molecules → compute_properties → \
screen_candidates → generate_report using the initial target_profile exactly.

**Round 2 (required — do not skip):** call screen_candidates then generate_report \
again, reusing the same molecule_dir and properties_json from round 1, with this \
relaxed target_profile (wider LogP window only): \
{RESCREEN_DEMO_TARGET_PROFILE_RESCREEN}

Pass every tool return value verbatim to the next step. After both rounds, \
summarize how the funnel and top candidates changed between round 1 and round 2."""

# {{docs-fragment rescreen_demo}}
@env.task(report=True)
async def rescreen_demo() -> str:
    """Example run with a two-round execution graph (rescreen).

    Round 1 uses a strict CNS-like profile; round 2 always re-runs
    screen_candidates and generate_report with a widened LogP window,
    reusing cached molecule_dir and properties_json.
    """
    return await pipeline(
        brief=RESCREEN_DEMO_BRIEF,
        target_profile=RESCREEN_DEMO_TARGET_PROFILE,
    )

# {{/docs-fragment rescreen_demo}}

# {{docs-fragment main}}
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(pipeline)
    print(run.url)
    run.wait()
# {{/docs-fragment main}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/drug_molecule_screening/drug_molecule_screening.py*
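The property box plots in the report stage reduce each property to five numbers: minimum, lower quartile, median, upper quartile, and maximum. The quartiles are taken by direct index lookup (`vals[n // 4]` and `vals[3 * n // 4]`) rather than by interpolation, and fall back to the extremes when there are fewer than four molecules. The sketch below pulls that reduction out into a hypothetical `five_number_summary` helper (not part of the tutorial's source) so the arithmetic can be checked on its own.

```
def five_number_summary(values: list[float]) -> tuple[float, float, float, float, float]:
    """Return (min, q1, median, q3, max) using the report's index-based quartiles.

    Hypothetical helper mirroring the box-plot code; unlike the report loop,
    which skips empty properties with `continue`, it raises on empty input.
    """
    vals = sorted(values)
    n = len(vals)
    if n == 0:
        raise ValueError("need at least one value")
    v_min, v_max = vals[0], vals[-1]
    # Median: middle element for odd n, mean of the two middle elements for even n.
    if n % 2 == 1:
        median = vals[n // 2]
    else:
        median = (vals[n // 2 - 1] + vals[n // 2]) / 2
    # Quartiles: plain index lookup, no interpolation; tiny samples use the extremes.
    q1 = vals[n // 4] if n >= 4 else v_min
    q3 = vals[3 * n // 4] if n >= 4 else v_max
    return v_min, q1, median, q3, v_max


print(five_number_summary([9, 4, 7, 2, 6, 8]))  # (2, 4, 6.5, 8, 9)
```

Index-based quartiles are coarse for very small libraries, but they are cheap and good enough for a visual summary; `statistics.quantiles` is an alternative if interpolated quartiles are wanted.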
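The chemical-diversity section reads only the upper triangle of `sim_matrix` (pairs with `i < j`), so the diagonal's self-similarity of 1.0 never inflates the average, and it labels the library using average-similarity thresholds of 0.3 and 0.5. The matrix itself is built earlier in the script; assuming it holds Tanimoto similarities between molecular fingerprints (as the report text implies), the following minimal sketch reproduces the same summary with fingerprints represented as Python sets of "on" bit indices. The names `tanimoto` and `diversity_summary` are hypothetical and stand in for RDKit's fingerprint utilities.

```
from itertools import combinations


def tanimoto(a: set[int], b: set[int]) -> float:
    """Tanimoto (Jaccard) similarity of two fingerprints given as sets of on-bits."""
    union = len(a | b)
    return len(a & b) / union if union else 0.0


def diversity_summary(fps: dict[str, set[int]]) -> dict:
    """Average/min/max similarity over distinct pairs, plus the most similar pair."""
    names = list(fps)
    if len(names) < 2:
        raise ValueError("need at least two molecules")
    # combinations() yields each unordered pair (i < j) once: the upper triangle.
    pairs = {(x, y): tanimoto(fps[x], fps[y]) for x, y in combinations(names, 2)}
    sims = list(pairs.values())
    avg = sum(sims) / len(sims)
    # max() keeps the first maximal pair, like the report's strict `>` comparison.
    closest = max(pairs, key=pairs.get)
    # Same labelling thresholds as the report.
    if avg < 0.3:
        label = "highly diverse"
    elif avg < 0.5:
        label = "moderately diverse"
    else:
        label = "relatively similar"
    return {"avg": avg, "min": min(sims), "max": max(sims), "closest": closest, "label": label}


fps = {"A": {1, 2, 3, 4}, "B": {1, 2, 3, 5}, "C": {7, 8}}
print(diversity_summary(fps)["closest"])  # ('A', 'B')
```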
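The agent's reflection step allows at most one rescreen, in which exactly one profile bound is relaxed by roughly 10–20%. The instructions leave open what the percentage is relative to. The sketch below assumes it means a fraction of the property's current window width (`max - min`), which behaves sensibly for bounds at or below zero such as LogP. `relax_bound` is a hypothetical helper; in the tutorial the agent performs this edit itself. Note that `rescreen_demo` deliberately goes much further, widening the LogP window from `[-0.5, 1.0]` to `[-0.5, 3.5]` so that the second round differs visibly.

```
import json


def relax_bound(profile_json: str, key: str, side: str, fraction: float = 0.15) -> str:
    """Widen one bound of a target_profile by `fraction` of that property's window width.

    Hypothetical helper: only `key` changes; every other range is left untouched.
    """
    if fraction <= 0:
        raise ValueError("fraction must be positive")
    profile = json.loads(profile_json)
    lo, hi = profile[key]
    delta = (hi - lo) * fraction
    if side == "max":
        hi += delta
    elif side == "min":
        lo -= delta
    else:
        raise ValueError("side must be 'min' or 'max'")
    # Round to keep the JSON readable for the agent and the report.
    profile[key] = [round(lo, 3), round(hi, 3)]
    return json.dumps(profile)


strict = '{"mw": [150, 200], "logp": [-0.5, 1.0]}'
print(relax_bound(strict, "logp", "max", 0.2))  # {"mw": [150, 200], "logp": [-0.5, 1.3]}
```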
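The report shows a Lipinski compliance badge per molecule and draws reference lines at MW = 500 and LogP = 5, but the pass/fail flag itself (`lipinski_pass`) is computed upstream, before the report stage. As a minimal sketch of what such a check involves, the code below applies the standard Rule of Five thresholds: molecular weight at most 500 Da, LogP at most 5, at most 5 hydrogen-bond donors, and at most 10 hydrogen-bond acceptors. Lipinski's original formulation tolerates one violation; the report's wording ("full Lipinski Rule of Five compliance") suggests zero, so this hypothetical `passes_lipinski` helper defaults `max_violations` to 0. The property values in the example are illustrative, not RDKit output.

```
# Standard Rule of Five upper limits, keyed like the tutorial's property dicts.
LIPINSKI_LIMITS = {"mw": 500.0, "logp": 5.0, "hbd": 5, "hba": 10}


def lipinski_violations(props: dict) -> list[str]:
    """Return the names of properties that exceed their Rule of Five limit."""
    return [key for key, limit in LIPINSKI_LIMITS.items() if props[key] > limit]


def passes_lipinski(props: dict, max_violations: int = 0) -> bool:
    """True if the molecule has at most `max_violations` Rule of Five violations."""
    return len(lipinski_violations(props)) <= max_violations


heavy = {"mw": 550.0, "logp": 3.0, "hbd": 2, "hba": 6}  # illustrative values
print(lipinski_violations(heavy))                  # ['mw']
print(passes_lipinski(heavy))                      # False
print(passes_lipinski(heavy, max_violations=1))    # True
```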

## Run the agentic pipeline

The `pipeline` task delegates to the screening agent:

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.5.4",
#    "litellm",
#    "rdkit",
#    "numpy",
#    "scikit-learn",
#    "pillow",
# ]
# main = "pipeline"
# params = ""
# ///
"""Virtual drug molecule screening — compute properties, apply Lipinski filters, rank candidates."""

import base64
import io
import json
import logging
import math
import os
import tempfile

import flyte
import flyte.io
import flyte.report
from flyte.ai.agents import Agent, tool

MODEL = os.getenv("DRUG_SCREENING_MODEL", "claude-haiku-4-5")

# {{docs-fragment env}}
main_img = flyte.Image.from_uv_script(__file__, name="drug-molecule-screening", pre=True).with_apt_packages(
    "libxrender1", "libxext6", "libexpat1",
)

env = flyte.TaskEnvironment(
    name="drug-molecule-screening",
    image=main_img,
    resources=flyte.Resources(cpu=2, memory="6Gi"),
    secrets=[
        flyte.Secret(key="internal-anthropic-api-key", as_env_var="ANTHROPIC_API_KEY"),
    ],
)
# {{/docs-fragment env}}

logging.basicConfig(level=logging.WARNING, format="%(message)s", force=True)
log = logging.getLogger(__name__)
log.setLevel(logging.INFO)

# ------------------------------------------------------------------
# Default molecule library — real SMILES for well-known drugs
# ------------------------------------------------------------------

DEFAULT_MOLECULES = {
    "Aspirin": "CC(=O)OC1=CC=CC=C1C(=O)O",
    "Ibuprofen": "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O",
    "Caffeine": "CN1C=NC2=C1C(=O)N(C(=O)N2C)C",
    "Penicillin G": "CC1(C(N2C(S1)C(C2=O)NC(=O)CC3=CC=CC=C3)C(=O)O)C",
    "Metformin": "CN(C)C(=N)NC(=N)N",
    "Paracetamol": "CC(=O)NC1=CC=C(C=C1)O",
    "Diazepam": "ClC1=CC2=C(C=C1)N(C(=O)CN=C2C3=CC=CC=C3)C",
    "Omeprazole": "CC1=CN=C(C(=C1OC)C)CS(=O)C2=NC3=CC=CC=C3N2",
    "Atorvastatin": "CC(C)C1=C(C(=C(N1CCC(CC(CC(=O)O)O)O)C2=CC=C(C=C2)F)C3=CC=CC=C3)C(=O)NC4=CC=CC=C4",
    "Methotrexate": "CN(CC1=CN=C2N=C(N=C(N)C2=N1)N)C3=CC=C(C=C3)C(=O)NC(CCC(=O)O)C(=O)O",
    "Doxorubicin": "CC1C(C(CC(O1)OC2CC(CC3=C2C(=C4C(=C3O)C(=O)C5=C(C4=O)C(=CC=C5)OC)O)(C(=O)CO)O)N)O",
    "Tamoxifen": "CCC(=C(C1=CC=CC=C1)C2=CC=C(C=C2)OCCN(C)C)C3=CC=CC=C3",
    "Lopinavir": "CC1=C(C(=CC=C1)C)OCC(=O)NC(CC2=CC=CC=C2)C(CC(CC3=CC=CC=C3)NC(=O)C(C(C)C)N4CCCNC4=O)O",
    "Remdesivir": "CCC(CC)COC(=O)C(C)NP(=O)(OCC1C(C(C(O1)C2=CC=C3N2N=CN=C3N)O)O)OC4=CC=CC=C4",
    "Erlotinib": "COCCOC1=CC2=C(C=C1OCCOC)C(=NC=N2)NC3=CC=CC(=C3)C#C",
}

# ------------------------------------------------------------------
# Report styling — pharma blue/cyan theme
# ------------------------------------------------------------------

REPORT_CSS = """
<style>
  .report { font-family: system-ui, -apple-system, sans-serif; max-width: 960px; margin: 0 auto; color: #1a1a2e; }
  .report h2 { color: #0e4f6e; border-bottom: 2px solid #0891b2; padding-bottom: 8px; margin-top: 24px; }
  .report h3 { color: #155e75; margin-top: 20px; }
  .report .card { background: #ecfeff; border: 1px solid #a5f3fc; border-radius: 8px; padding: 16px; margin: 12px 0; }
  .report .stat-grid { display: grid; grid-template-columns: repeat(auto-fit, minmax(160px, 1fr)); gap: 12px; margin: 12px 0; }
  .report .stat { background: #fff; border: 1px solid #cffafe; border-radius: 6px; padding: 12px; text-align: center; }
  .report .stat .value { font-size: 1.5em; font-weight: 700; color: #0e4f6e; }
  .report .stat .label { font-size: 0.85em; color: #6c757d; margin-top: 4px; }
  .report table { border-collapse: collapse; width: 100%; margin: 12px 0; }
  .report th { background: #0e4f6e; color: #fff; padding: 10px 14px; text-align: left; font-weight: 600; }
  .report td { padding: 8px 14px; border-bottom: 1px solid #cffafe; }
  .report tr:nth-child(even) { background: #ecfeff; }
  .report .badge { display: inline-block; padding: 2px 8px; border-radius: 12px; font-size: 0.8em; font-weight: 600; }
  .report .badge-success { background: #d1fae5; color: #065f46; }
  .report .badge-danger { background: #fee2e2; color: #991b1b; }
  .report .badge-info { background: #cffafe; color: #155e75; }
  .report .chart-container { background: #fff; border: 1px solid #cffafe; border-radius: 8px; padding: 16px; margin: 16px 0; }
  .report .note { background: #ecfeff; border-left: 4px solid #0891b2; padding: 10px 14px; border-radius: 4px; margin: 12px 0; font-size: 0.9em; }
  .report .molecule-card { background: #fff; border: 1px solid #cffafe; border-radius: 8px; padding: 16px; margin: 12px 0; }
  .report .molecule-grid { display: grid; grid-template-columns: repeat(auto-fill, minmax(200px, 1fr)); gap: 12px; margin: 16px 0; }
  .report .funnel { text-align: center; margin: 24px 0; }
</style>
"""

def _wrap_report(html: str) -> str:
    """Wrap HTML content with report styling."""
    return f'{REPORT_CSS}<div class="report">{html}</div>'

# ------------------------------------------------------------------
# Image and SVG chart helpers
# ------------------------------------------------------------------

def _mol_to_data_uri(mol, size: tuple[int, int] = (300, 300)) -> str:
    """Convert an RDKit molecule to a PNG base64 data URI."""
    from rdkit.Chem import Draw

    img = Draw.MolToImage(mol, size=size)
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    b64 = base64.b64encode(buf.getvalue()).decode()
    return f"data:image/png;base64,{b64}"

def _make_bar_chart(
    labels: list[str],
    series: dict[str, list[float]],
    title: str = "",
    colors: list[str] | None = None,
    width: int = 700,
    height: int = 340,
    y_max_cap: float | None = None,
    horizontal: bool = False,
    value_fmt: str = ".1f",
) -> str:
    """Generate an SVG grouped bar chart.

    Args:
        labels: Category labels.
        series: Dict mapping series name to list of values.
        title: Chart title.
        colors: Colors for each series.
        width/height: SVG dimensions.
        y_max_cap: Cap the y-axis at this value.
        horizontal: If True, draw horizontal bars.
        value_fmt: Format string for value labels.

    Returns:
        SVG string.
    """
    if not labels:
        return ""

    default_colors = ["#0891b2", "#0e4f6e", "#06d6a0", "#a5f3fc", "#155e75"]
    colors = colors or default_colors

    if horizontal:
        return _make_horizontal_bar_chart(labels, series, title, colors, width, height, value_fmt)

    ml, mr, mt, mb = 60, 20, 40, 60
    cw = width - ml - mr
    ch = height - mt - mb

    all_vals = [v for vals in series.values() for v in vals]
    y_max = max(all_vals) if all_vals else 1
    y_max_plot = y_max * 1.15 or 1
    if y_max_cap is not None:
        y_max_plot = min(y_max_plot, y_max_cap) or y_max_cap

    n_groups = len(labels)
    n_series = len(series)
    group_width = cw / n_groups
    bar_width = group_width * 0.7 / max(n_series, 1)
    gap = group_width * 0.15

    def sy(v):
        return mt + ch - (v / y_max_plot) * ch

    svg = [
        f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {height}" '
        f'style="width:100%;max-width:{width}px;height:auto;">',
        f'<rect width="{width}" height="{height}" fill="#fff" rx="6"/>',
    ]

    # Grid lines
    for i in range(6):
        y_tick = y_max_plot * i / 5
        py = sy(y_tick)
        svg.append(
            f'<line x1="{ml}" y1="{py:.1f}" x2="{ml + cw}" y2="{py:.1f}" '
            f'stroke="#e0f2fe" stroke-width="1"/>'
        )
        svg.append(
            f'<text x="{ml - 8}" y="{py + 4:.1f}" text-anchor="end" '
            f'font-size="11" fill="#6c757d">{y_tick:{value_fmt}}</text>'
        )

    # Axes
    svg.append(
        f'<line x1="{ml}" y1="{mt}" x2="{ml}" y2="{mt + ch}" '
        f'stroke="#94a3b8" stroke-width="1.5"/>'
    )
    svg.append(
        f'<line x1="{ml}" y1="{mt + ch}" x2="{ml + cw}" y2="{mt + ch}" '
        f'stroke="#94a3b8" stroke-width="1.5"/>'
    )

    # Bars
    for gi, label in enumerate(labels):
        gx = ml + gi * group_width + gap
        for si, (name, vals) in enumerate(series.items()):
            color = colors[si % len(colors)]
            bx = gx + si * bar_width
            val = vals[gi]
            by = sy(val)
            bh = mt + ch - by
            svg.append(
                f'<rect x="{bx:.1f}" y="{by:.1f}" width="{bar_width - 1:.1f}" '
                f'height="{bh:.1f}" fill="{color}" rx="2"/>'
            )
            svg.append(
                f'<text x="{bx + bar_width / 2:.1f}" y="{by - 4:.1f}" '
                f'text-anchor="middle" font-size="9" fill="#1a1a2e">'
                f'{val:{value_fmt}}</text>'
            )
        # Truncate long labels
        disp_label = label if len(label) <= 12 else label[:10] + ".."
        svg.append(
            f'<text x="{gx + n_series * bar_width / 2:.1f}" y="{mt + ch + 16}" '
            f'text-anchor="middle" font-size="10" fill="#6c757d" '
            f'transform="rotate(-35, {gx + n_series * bar_width / 2:.1f}, {mt + ch + 16})">'
            f'{disp_label}</text>'
        )

    # Title
    if title:
        svg.append(
            f'<text x="{width / 2}" y="22" text-anchor="middle" '
            f'font-size="14" font-weight="600" fill="#0e4f6e">{title}</text>'
        )

    # Legend
    if n_series > 1:
        lx = ml + cw - len(series) * 100
        for si, name in enumerate(series):
            color = colors[si % len(colors)]
            svg.append(
                f'<rect x="{lx + si * 100}" y="{mt + ch + 40}" width="12" '
                f'height="12" rx="2" fill="{color}"/>'
            )
            svg.append(
                f'<text x="{lx + si * 100 + 16}" y="{mt + ch + 51}" font-size="11" '
                f'fill="#1a1a2e">{name}</text>'
            )

    svg.append("</svg>")
    return "\n".join(svg)

def _make_horizontal_bar_chart(
    labels: list[str],
    series: dict[str, list[float]],
    title: str = "",
    colors: list[str] | None = None,
    width: int = 700,
    height: int = 400,
    value_fmt: str = ".1f",
) -> str:
    """Generate an SVG horizontal bar chart (sorted)."""
    default_colors = ["#0891b2", "#0e4f6e", "#06d6a0"]
    colors = colors or default_colors

    n = len(labels)
    row_height = max(22, min(35, (height - 80) // max(n, 1)))
    actual_height = max(height, 80 + n * row_height)
    ml, mr, mt, mb = 120, 60, 40, 20
    cw = width - ml - mr
    ch = actual_height - mt - mb

    # Use first series
    first_key = list(series.keys())[0]
    vals = series[first_key]
    x_max = max(vals) * 1.15 if vals else 1

    svg = [
        f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {actual_height}" '
        f'style="width:100%;max-width:{width}px;height:auto;">',
        f'<rect width="{width}" height="{actual_height}" fill="#fff" rx="6"/>',
    ]

    if title:
        svg.append(
            f'<text x="{width / 2}" y="22" text-anchor="middle" '
            f'font-size="14" font-weight="600" fill="#0e4f6e">{title}</text>'
        )

    bar_h = row_height * 0.65
    for i, (label, val) in enumerate(zip(labels, vals)):
        y = mt + i * row_height
        bw = (val / x_max) * cw if x_max else 0
        color = colors[i % len(colors)]
        # Label
        disp = label if len(label) <= 14 else label[:12] + ".."
        svg.append(
            f'<text x="{ml - 8}" y="{y + bar_h / 2 + 4:.1f}" text-anchor="end" '
            f'font-size="11" fill="#1a1a2e">{disp}</text>'
        )
        # Bar
        svg.append(
            f'<rect x="{ml}" y="{y:.1f}" width="{bw:.1f}" height="{bar_h:.1f}" '
            f'fill="{color}" rx="3"/>'
        )
        # Value
        svg.append(
            f'<text x="{ml + bw + 6:.1f}" y="{y + bar_h / 2 + 4:.1f}" '
            f'font-size="11" fill="#0e4f6e" font-weight="600">{val:{value_fmt}}</text>'
        )

    svg.append("</svg>")
    return "\n".join(svg)

def _make_heatmap(
    matrix: list[list[float]],
    row_labels: list[str],
    col_labels: list[str],
    title: str = "",
    color_scale: str = "cyan",
    width: int = 700,
    height: int = 500,
    value_fmt: str = ".2f",
) -> str:
    """Generate an SVG heatmap.

    Args:
        matrix: 2D list of values (rows x cols).
        row_labels: Labels for rows.
        col_labels: Labels for columns.
        title: Chart title.
        color_scale: Color scheme ("cyan", "red", "green").
        width/height: SVG dimensions.
        value_fmt: Format string for cell values.

    Returns:
        SVG string.
    """
    if not matrix or not matrix[0]:
        return ""

    n_rows = len(matrix)
    n_cols = len(matrix[0])

    ml, mr, mt, mb = 110, 20, 70, 20
    cw = width - ml - mr
    ch = height - mt - mb
    cell_w = cw / n_cols
    cell_h = ch / n_rows

    # Flatten to find range
    flat = [v for row in matrix for v in row]
    v_min = min(flat)
    v_max = max(flat)
    v_range = v_max - v_min or 1

    def color_for(v):
        t = (v - v_min) / v_range
        if color_scale == "cyan":
            # White to deep teal
            r = int(255 - t * (255 - 14))
            g = int(255 - t * (255 - 79))
            b = int(255 - t * (255 - 110))
        elif color_scale == "red":
            r = int(255 - t * 50)
            g = int(255 - t * 200)
            b = int(255 - t * 200)
        else:  # green
            r = int(255 - t * 200)
            g = int(255 - t * 50)
            b = int(255 - t * 200)
        return f"rgb({r},{g},{b})"

    svg = [
        f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {height}" '
        f'style="width:100%;max-width:{width}px;height:auto;">',
        f'<rect width="{width}" height="{height}" fill="#fff" rx="6"/>',
    ]

    if title:
        svg.append(
            f'<text x="{width / 2}" y="22" text-anchor="middle" '
            f'font-size="14" font-weight="600" fill="#0e4f6e">{title}</text>'
        )

    # Column labels (rotated)
    for ci, label in enumerate(col_labels):
        x = ml + ci * cell_w + cell_w / 2
        disp = label if len(label) <= 12 else label[:10] + ".."
        svg.append(
            f'<text x="{x:.1f}" y="{mt - 8}" text-anchor="end" font-size="10" '
            f'fill="#1a1a2e" transform="rotate(-45, {x:.1f}, {mt - 8})">{disp}</text>'
        )

    # Row labels + cells
    for ri, (row_label, row_vals) in enumerate(zip(row_labels, matrix)):
        y = mt + ri * cell_h
        disp = row_label if len(row_label) <= 14 else row_label[:12] + ".."
        svg.append(
            f'<text x="{ml - 8}" y="{y + cell_h / 2 + 4:.1f}" text-anchor="end" '
            f'font-size="10" fill="#1a1a2e">{disp}</text>'
        )
        for ci, val in enumerate(row_vals):
            x = ml + ci * cell_w
            fill = color_for(val)
            svg.append(
                f'<rect x="{x:.1f}" y="{y:.1f}" width="{cell_w:.1f}" '
                f'height="{cell_h:.1f}" fill="{fill}" stroke="#fff" stroke-width="1"/>'
            )
            # Text color: dark on light, light on dark
            t = (val - v_min) / v_range
            txt_color = "#fff" if t > 0.55 else "#1a1a2e"
            # Only show text if cells are large enough
            if cell_w > 30 and cell_h > 18:
                svg.append(
                    f'<text x="{x + cell_w / 2:.1f}" y="{y + cell_h / 2 + 4:.1f}" '
                    f'text-anchor="middle" font-size="9" fill="{txt_color}">'
                    f'{val:{value_fmt}}</text>'
                )

    svg.append("</svg>")
    return "\n".join(svg)

def _make_scatter_plot(
    points: list[dict],
    x_label: str = "MW",
    y_label: str = "LogP",
    title: str = "",
    reference_lines: list[dict] | None = None,
    width: int = 700,
    height: int = 400,
) -> str:
    """Generate an SVG scatter plot.

    Args:
        points: List of dicts with "x", "y", "label" keys.
        x_label/y_label: Axis labels.
        title: Chart title.
        reference_lines: List of dicts with "axis" ("x"/"y"), "value", "label".
        width/height: SVG dimensions.

    Returns:
        SVG string.
    """
    if not points:
        return ""

    ml, mr, mt, mb = 60, 30, 40, 50
    cw = width - ml - mr
    ch = height - mt - mb

    x_vals = [p["x"] for p in points]
    y_vals = [p["y"] for p in points]
    x_min, x_max = min(x_vals) * 0.9, max(x_vals) * 1.1
    y_min, y_max = min(y_vals) - 1, max(y_vals) + 1

    # Extend ranges to include reference lines
    if reference_lines:
        for rl in reference_lines:
            if rl["axis"] == "x":
                x_max = max(x_max, rl["value"] * 1.1)
            else:
                y_max = max(y_max, rl["value"] * 1.1)

    x_range = x_max - x_min or 1
    y_range = y_max - y_min or 1

    def sx(v):
        return ml + (v - x_min) / x_range * cw

    def sy(v):
        return mt + ch - (v - y_min) / y_range * ch

    svg = [
        f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {height}" '
        f'style="width:100%;max-width:{width}px;height:auto;">',
        f'<rect width="{width}" height="{height}" fill="#fff" rx="6"/>',
    ]

    # Grid
    for i in range(6):
        y_tick = y_min + y_range * i / 5
        py = sy(y_tick)
        svg.append(
            f'<line x1="{ml}" y1="{py:.1f}" x2="{ml + cw}" y2="{py:.1f}" '
            f'stroke="#e0f2fe" stroke-width="1"/>'
        )
        svg.append(
            f'<text x="{ml - 8}" y="{py + 4:.1f}" text-anchor="end" '
            f'font-size="11" fill="#6c757d">{y_tick:.1f}</text>'
        )

    for i in range(6):
        x_tick = x_min + x_range * i / 5
        px = sx(x_tick)
        svg.append(
            f'<text x="{px:.1f}" y="{mt + ch + 20}" text-anchor="middle" '
            f'font-size="11" fill="#6c757d">{x_tick:.0f}</text>'
        )

    # Axes
    svg.append(
        f'<line x1="{ml}" y1="{mt}" x2="{ml}" y2="{mt + ch}" '
        f'stroke="#94a3b8" stroke-width="1.5"/>'
    )
    svg.append(
        f'<line x1="{ml}" y1="{mt + ch}" x2="{ml + cw}" y2="{mt + ch}" '
        f'stroke="#94a3b8" stroke-width="1.5"/>'
    )

    # Reference lines (Lipinski boundaries)
    if reference_lines:
        for rl in reference_lines:
            if rl["axis"] == "x":
                px = sx(rl["value"])
                svg.append(
                    f'<line x1="{px:.1f}" y1="{mt}" x2="{px:.1f}" y2="{mt + ch}" '
                    f'stroke="#ef4444" stroke-width="1.5" stroke-dasharray="6,4"/>'
                )
                svg.append(
                    f'<text x="{px + 4:.1f}" y="{mt + 14}" font-size="10" '
                    f'fill="#ef4444" font-weight="600">{rl["label"]}</text>'
                )
            else:
                py = sy(rl["value"])
                svg.append(
                    f'<line x1="{ml}" y1="{py:.1f}" x2="{ml + cw}" y2="{py:.1f}" '
                    f'stroke="#ef4444" stroke-width="1.5" stroke-dasharray="6,4"/>'
                )
                svg.append(
                    f'<text x="{ml + cw - 4:.1f}" y="{py - 6:.1f}" text-anchor="end" '
                    f'font-size="10" fill="#ef4444" font-weight="600">{rl["label"]}</text>'
                )

    # Drug-like zone shading (MW<=500 and LogP<=5 quadrant)
    if reference_lines:
        mw_line = next((rl for rl in reference_lines if rl["axis"] == "x"), None)
        logp_line = next((rl for rl in reference_lines if rl["axis"] == "y"), None)
        if mw_line and logp_line:
            zx1 = sx(x_min)
            zx2 = sx(min(mw_line["value"], x_max))
            zy1 = sy(min(logp_line["value"], y_max))
            zy2 = sy(y_min)
            svg.append(
                f'<rect x="{zx1:.1f}" y="{zy1:.1f}" '
                f'width="{zx2 - zx1:.1f}" height="{zy2 - zy1:.1f}" '
                f'fill="#0891b2" opacity="0.06" rx="4"/>'
            )
            svg.append(
                f'<text x="{zx1 + 8:.1f}" y="{zy2 - 8:.1f}" font-size="11" '
                f'fill="#0891b2" font-weight="600" opacity="0.6">Drug-like Zone</text>'
            )

    # Points
    point_colors = ["#0891b2", "#0e4f6e", "#06d6a0", "#155e75", "#0284c7",
                    "#059669", "#0d9488", "#0369a1", "#047857", "#115e59",
                    "#0c4a6e", "#064e3b", "#1e3a5f", "#134e4a", "#075985"]
    for i, pt in enumerate(points):
        px, py = sx(pt["x"]), sy(pt["y"])
        color = point_colors[i % len(point_colors)]
        svg.append(
            f'<circle cx="{px:.1f}" cy="{py:.1f}" r="5" fill="{color}" '
            f'stroke="#fff" stroke-width="1.5" opacity="0.85"/>'
        )
        # Label offset to avoid overlap
        offset_x = 8
        offset_y = -8 if i % 2 == 0 else 14
        label = pt["label"] if len(pt["label"]) <= 12 else pt["label"][:10] + ".."
        svg.append(
            f'<text x="{px + offset_x:.1f}" y="{py + offset_y:.1f}" '
            f'font-size="9" fill="#1a1a2e">{label}</text>'
        )

    # Title
    if title:
        svg.append(
            f'<text x="{width / 2}" y="22" text-anchor="middle" '
            f'font-size="14" font-weight="600" fill="#0e4f6e">{title}</text>'
        )

    # Axis labels
    if x_label:
        svg.append(
            f'<text x="{ml + cw / 2}" y="{height - 6}" text-anchor="middle" '
            f'font-size="12" fill="#6c757d">{x_label}</text>'
        )
    if y_label:
        svg.append(
            f'<text x="14" y="{mt + ch / 2}" text-anchor="middle" '
            f'font-size="12" fill="#6c757d" '
            f'transform="rotate(-90, 14, {mt + ch / 2})">{y_label}</text>'
        )

    svg.append("</svg>")
    return "\n".join(svg)
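
# Sketch (illustrative, not called by the tasks below): the scatter plot maps
# data coordinates to SVG pixels with a linear transform. SVG's y-axis grows
# downward, so larger data values must map to smaller pixel y values. The
# sx/sy helpers are defined earlier in _make_scatter_plot. The hypothetical
# _linear_scale below assumes they follow this pattern, with ml/mt as the
# left/top margins and cw/ch as the plot-area width/height:
#   sx = _linear_scale(x_min, x_max, ml, ml + cw)
#   sy = _linear_scale(y_min, y_max, mt + ch, mt)  # inverted y-axis
def _linear_scale(d_min: float, d_max: float, p_start: float, p_end: float):
    """Return a function mapping the data range [d_min, d_max] onto [p_start, p_end]."""
    span = (d_max - d_min) or 1.0  # avoid division by zero for flat data

    def scale(v: float) -> float:
        return p_start + (v - d_min) / span * (p_end - p_start)

    return scale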

def _make_funnel(
    stages: list[dict],
    title: str = "",
    width: int = 600,
    height: int = 400,
) -> str:
    """Generate an SVG funnel visualization.

    Args:
        stages: List of dicts with "label", "count", "total" keys.
        title: Chart title.
        width: SVG viewBox width.
        height: SVG viewBox height.

    Returns:
        SVG string.
    """
    if not stages:
        return ""

    n = len(stages)
    mt = 50
    mb = 20
    available_h = height - mt - mb
    stage_h = available_h / n
    cx = width / 2

    # Color gradient from light cyan to deep teal
    colors = []
    for i in range(n):
        t = i / max(n - 1, 1)
        r = int(207 - t * (207 - 14))
        g = int(250 - t * (250 - 79))
        b = int(254 - t * (254 - 110))
        colors.append(f"rgb({r},{g},{b})")

    svg = [
        f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {height}" '
        f'style="width:100%;max-width:{width}px;height:auto;">',
        f'<rect width="{width}" height="{height}" fill="#fff" rx="6"/>',
    ]

    if title:
        svg.append(
            f'<text x="{cx}" y="28" text-anchor="middle" '
            f'font-size="16" font-weight="700" fill="#0e4f6e">{title}</text>'
        )

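    # `or 1` guards against a zero-count first stage (avoids ZeroDivisionError).
    max_count = stages[0]["count"] or 1
    max_width = width * 0.75
    prev_w_bot = max_width  # bottom width of the previous stage, updated each iteration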

    for i, stage in enumerate(stages):
        y_top = mt + i * stage_h
        y_bot = y_top + stage_h

        # Width proportional to count
        w_top = max_width * (stage["count"] / max_count) if i == 0 else prev_w_bot
        if i < n - 1:
            w_bot = max_width * (stages[i + 1]["count"] / max_count)
        else:
            w_bot = max_width * (stage["count"] / max_count) * 0.7

        prev_w_bot = w_bot

        # Trapezoid
        x1_top = cx - w_top / 2
        x2_top = cx + w_top / 2
        x1_bot = cx - w_bot / 2
        x2_bot = cx + w_bot / 2

        svg.append(
            f'<polygon points="{x1_top:.1f},{y_top:.1f} {x2_top:.1f},{y_top:.1f} '
            f'{x2_bot:.1f},{y_bot:.1f} {x1_bot:.1f},{y_bot:.1f}" '
            f'fill="{colors[i]}" stroke="#fff" stroke-width="2"/>'
        )

        # Text: dark on light, white on dark
        t = i / max(n - 1, 1)
        txt_color = "#0e4f6e" if t < 0.5 else "#fff"
        y_mid = (y_top + y_bot) / 2

        svg.append(
            f'<text x="{cx}" y="{y_mid - 4:.1f}" text-anchor="middle" '
            f'font-size="13" font-weight="600" fill="{txt_color}">{stage["label"]}</text>'
        )
        svg.append(
            f'<text x="{cx}" y="{y_mid + 14:.1f}" text-anchor="middle" '
            f'font-size="12" fill="{txt_color}" opacity="0.85">'
            f'{stage["count"]} / {stage["total"]}</text>'
        )

    svg.append("</svg>")
    return "\n".join(svg)
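
# Sketch (illustrative): the funnel colors above come from linear interpolation
# between two RGB endpoints: rgb(207,250,254) (light cyan) and rgb(14,79,110)
# (#0e4f6e, deep teal). The hypothetical helper below isolates that step so it
# can be checked on its own. _make_funnel does not call it.
def _interpolate_rgb(
    start: tuple[int, int, int],
    end: tuple[int, int, int],
    i: int,
    n: int,
) -> str:
    """Return the CSS rgb() color for stage i of n, blending from start to end."""
    t = i / max(n - 1, 1)  # 0.0 for the first stage, 1.0 for the last
    r, g, b = (int(s - t * (s - e)) for s, e in zip(start, end))
    return f"rgb({r},{g},{b})"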

# ------------------------------------------------------------------
# Task 1: Load and validate molecules
# ------------------------------------------------------------------

@tool
@env.task(cache="auto")
async def load_molecules(
    molecules_json: str = "",
) -> flyte.io.Dir:
    """Parse SMILES strings, validate with RDKit, generate 2D depictions.

    Args:
        molecules_json: JSON string mapping molecule names to SMILES.
            Defaults to a curated library of ~15 well-known drugs.

    Returns:
        flyte.io.Dir containing molecule data (JSON + PNG depictions).
        Pass this directory to compute_properties and generate_report.
    """
    from rdkit import Chem
    from rdkit.Chem import Draw

    if molecules_json.strip():
        molecules = json.loads(molecules_json)
    else:
        molecules = DEFAULT_MOLECULES

    out_dir = tempfile.mkdtemp(prefix="mol_library_")
    results = []
    valid_count = 0
    invalid_count = 0

    log.info(f"Parsing {len(molecules)} molecules...")

    for name, smiles in molecules.items():
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            log.warning(f"  [INVALID] {name}: {smiles}")
            invalid_count += 1
            continue

        valid_count += 1

        # Generate 2D depiction as PNG
        img = Draw.MolToImage(mol, size=(300, 300))
        img_path = os.path.join(out_dir, f"{name.replace(' ', '_')}.png")
        img.save(img_path)

        results.append({
            "name": name,
            "smiles": smiles,
            "valid": True,
            "image_file": os.path.basename(img_path),
        })

    # Save molecule manifest
    manifest = {
        "total": len(molecules),
        "valid": valid_count,
        "invalid": invalid_count,
        "molecules": results,
    }
    manifest_path = os.path.join(out_dir, "manifest.json")
    with open(manifest_path, "w") as f:
        json.dump(manifest, f, indent=2)

    log.info(f"Loaded {valid_count} valid molecules ({invalid_count} invalid)")

    return await flyte.io.Dir.from_local(out_dir)

# ------------------------------------------------------------------
# Task 2: Compute physicochemical properties
# ------------------------------------------------------------------

@tool
@env.task(report=True)
async def compute_properties(
    molecule_dir: flyte.io.Dir,
) -> str:
    """Compute drug-likeness properties for all molecules.

    Computes MW, LogP, HBD, HBA, TPSA, rotatable bonds, formal charge,
    ring count, QED, and Lipinski Rule of Five compliance.

    Args:
        molecule_dir: Directory from load_molecules.

    Returns:
        JSON string with all computed properties. Pass to screen_candidates
        and generate_report.
    """
    from rdkit import Chem
    from rdkit.Chem import Descriptors, Lipinski
    from rdkit.Chem.QED import qed

    # --- Loading report ---
    await flyte.report.replace.aio(
        _wrap_report("<h2>Computing Molecular Properties...</h2>"
                      "<p>Analyzing physicochemical descriptors for all molecules.</p>"),
        do_flush=True,
    )

    mol_dir = await molecule_dir.download()
    with open(os.path.join(mol_dir, "manifest.json")) as f:
        manifest = json.load(f)

    molecules_data = []
    lipinski_pass = 0

    for mol_info in manifest["molecules"]:
        mol = Chem.MolFromSmiles(mol_info["smiles"])
        if mol is None:
            continue

        mw = Descriptors.MolWt(mol)
        logp = Descriptors.MolLogP(mol)
        hbd = Lipinski.NumHDonors(mol)
        hba = Lipinski.NumHAcceptors(mol)
        tpsa = Descriptors.TPSA(mol)
        rotatable = Lipinski.NumRotatableBonds(mol)
        formal_charge = Chem.GetFormalCharge(mol)
        num_rings = Lipinski.RingCount(mol)
        qed_score = qed(mol)

        # Lipinski Rule of Five
        lipinski = {
            "mw_ok": mw <= 500,
            "logp_ok": logp <= 5,
            "hbd_ok": hbd <= 5,
            "hba_ok": hba <= 10,
        }
        lipinski_all = all(lipinski.values())
        if lipinski_all:
            lipinski_pass += 1

        # Read image for data URI
        img_path = os.path.join(mol_dir, mol_info["image_file"])
        data_uri = ""
        if os.path.exists(img_path):
            with open(img_path, "rb") as img_f:
                b64 = base64.b64encode(img_f.read()).decode()
                data_uri = f"data:image/png;base64,{b64}"

        molecules_data.append({
            "name": mol_info["name"],
            "smiles": mol_info["smiles"],
            "mw": round(mw, 2),
            "logp": round(logp, 2),
            "hbd": hbd,
            "hba": hba,
            "tpsa": round(tpsa, 2),
            "rotatable_bonds": rotatable,
            "formal_charge": formal_charge,
            "num_rings": num_rings,
            "qed": round(qed_score, 4),
            "lipinski": lipinski,
            "lipinski_pass": lipinski_all,
            "image_data_uri": data_uri,
        })

    total = len(molecules_data)
    avg_mw = sum(m["mw"] for m in molecules_data) / total if total else 0
    avg_logp = sum(m["logp"] for m in molecules_data) / total if total else 0
    lipinski_rate = lipinski_pass / total * 100 if total else 0

    # ---- Build report ----
    html_parts = []

    # Header
    html_parts.append("<h2>Molecular Properties Analysis</h2>")

    # Stat grid
    html_parts.append('<div class="stat-grid">')
    for val, label in [
        (str(total), "Total Molecules"),
        (f"{lipinski_rate:.0f}%", "Lipinski Pass Rate"),
        (f"{avg_mw:.1f}", "Avg. MW (Da)"),
        (f"{avg_logp:.2f}", "Avg. LogP"),
    ]:
        html_parts.append(
            f'<div class="stat"><div class="value">{val}</div>'
            f'<div class="label">{label}</div></div>'
        )
    html_parts.append("</div>")

    # Molecule gallery
    html_parts.append("<h3>Molecule Library</h3>")
    html_parts.append('<div class="molecule-grid">')
    for m in molecules_data:
        if m["image_data_uri"]:
            badge_class = "badge-success" if m["lipinski_pass"] else "badge-danger"
            badge_text = "Lipinski Pass" if m["lipinski_pass"] else "Lipinski Fail"
            html_parts.append(
                f'<div class="molecule-card" style="text-align:center;">'
                f'<img src="{m["image_data_uri"]}" style="width:160px;height:160px;object-fit:contain;"/>'
                f'<div style="font-weight:600;margin-top:6px;color:#0e4f6e;">{m["name"]}</div>'
                f'<div style="font-size:0.8em;color:#6c757d;">MW: {m["mw"]:.1f} | LogP: {m["logp"]:.2f}</div>'
                f'<div><span class="badge {badge_class}">{badge_text}</span></div>'
                f'</div>'
            )
    html_parts.append("</div>")

    # MW bar chart (horizontal, sorted)
    sorted_by_mw = sorted(molecules_data, key=lambda m: m["mw"], reverse=True)
    mw_labels = [m["name"] for m in sorted_by_mw]
    mw_vals = [m["mw"] for m in sorted_by_mw]
    mw_chart = _make_bar_chart(
        mw_labels, {"MW (Da)": mw_vals},
        title="Molecular Weight Distribution",
        horizontal=True,
        width=700, height=max(300, len(mw_labels) * 30 + 80),
        value_fmt=".1f",
    )
    html_parts.append("<h3>Molecular Weight</h3>")
    html_parts.append(f'<div class="chart-container">{mw_chart}</div>')

    # LogP vs MW scatter plot
    scatter_points = [
        {"x": m["mw"], "y": m["logp"], "label": m["name"]}
        for m in molecules_data
    ]
    scatter_chart = _make_scatter_plot(
        scatter_points,
        x_label="Molecular Weight (Da)",
        y_label="LogP",
        title="LogP vs. Molecular Weight (Lipinski Boundaries)",
        reference_lines=[
            {"axis": "x", "value": 500, "label": "MW = 500"},
            {"axis": "y", "value": 5, "label": "LogP = 5"},
        ],
        width=700,
        height=420,
    )
    html_parts.append("<h3>Lipinski Space</h3>")
    html_parts.append(f'<div class="chart-container">{scatter_chart}</div>')

    # Property heatmap (molecules x properties)
    prop_names = ["MW", "LogP", "HBD", "HBA", "TPSA", "Rot. Bonds"]
    # Normalize each property to 0-1 for heatmap
    raw_matrix = []
    for m in molecules_data:
        raw_matrix.append([m["mw"], m["logp"], m["hbd"], m["hba"], m["tpsa"], m["rotatable_bonds"]])

    # Normalize per column
    n_props = len(prop_names)
    col_min = [min(row[c] for row in raw_matrix) for c in range(n_props)]
    col_max = [max(row[c] for row in raw_matrix) for c in range(n_props)]
    norm_matrix = []
    for row in raw_matrix:
        norm_row = []
        for c in range(n_props):
            rng = col_max[c] - col_min[c]
            norm_row.append((row[c] - col_min[c]) / rng if rng else 0.5)
        norm_matrix.append(norm_row)

    heatmap_labels = [m["name"] for m in molecules_data]
    heatmap = _make_heatmap(
        norm_matrix, heatmap_labels, prop_names,
        title="Normalized Property Heatmap",
        color_scale="cyan",
        width=700,
        height=max(400, len(heatmap_labels) * 28 + 100),
    )
    html_parts.append("<h3>Property Heatmap</h3>")
    html_parts.append(f'<div class="chart-container">{heatmap}</div>')

    # Lipinski compliance table
    html_parts.append("<h3>Lipinski Rule of Five Compliance</h3>")
    html_parts.append("<table><tr><th>Molecule</th><th>MW &le; 500</th>"
                      "<th>LogP &le; 5</th><th>HBD &le; 5</th>"
                      "<th>HBA &le; 10</th><th>Overall</th></tr>")
    # Badge helper, defined once instead of on every loop iteration
    def _badge(ok):
        if ok:
            return '<span class="badge badge-success">Pass</span>'
        return '<span class="badge badge-danger">Fail</span>'

    for m in molecules_data:
        lip = m["lipinski"]
        overall_badge = _badge(m["lipinski_pass"])
        html_parts.append(
            f'<tr><td><strong>{m["name"]}</strong></td>'
            f'<td>{_badge(lip["mw_ok"])}</td>'
            f'<td>{_badge(lip["logp_ok"])}</td>'
            f'<td>{_badge(lip["hbd_ok"])}</td>'
            f'<td>{_badge(lip["hba_ok"])}</td>'
            f'<td>{overall_badge}</td></tr>'
        )
    html_parts.append("</table>")

    # QED bar chart
    sorted_by_qed = sorted(molecules_data, key=lambda m: m["qed"], reverse=True)
    qed_labels = [m["name"] for m in sorted_by_qed]
    qed_vals = [m["qed"] for m in sorted_by_qed]
    qed_chart = _make_bar_chart(
        qed_labels, {"QED Score": qed_vals},
        title="Drug-likeness (QED Score)",
        horizontal=True,
        width=700, height=max(300, len(qed_labels) * 30 + 80),
        value_fmt=".3f",
        colors=["#06d6a0"],
    )
    html_parts.append("<h3>Drug-likeness (QED)</h3>")
    html_parts.append(f'<div class="chart-container">{qed_chart}</div>')

    # Flush full report
    await flyte.report.replace.aio(
        _wrap_report("\n".join(html_parts)),
        do_flush=True,
    )

    # Return properties as JSON (strip image data URIs to reduce size)
    output = {
        "total": total,
        "lipinski_pass_count": lipinski_pass,
        "lipinski_pass_rate": round(lipinski_rate, 2),
        "avg_mw": round(avg_mw, 2),
        "avg_logp": round(avg_logp, 2),
        "molecules": [
            {k: v for k, v in m.items() if k != "image_data_uri"}
            for m in molecules_data
        ],
    }
    return json.dumps(output)
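
# Sketch (illustrative): the Rule of Five check in compute_properties reduces
# to four threshold comparisons on descriptors that have already been computed.
# The hypothetical pure-Python helper below mirrors those thresholds so the
# rule can be tested without RDKit. Like compute_properties, it requires all
# four rules to pass. Lipinski's original formulation tolerates one violation.
def _rule_of_five(mw: float, logp: float, hbd: int, hba: int) -> dict[str, bool]:
    """Return per-rule flags plus an overall 'pass' (all four rules satisfied)."""
    flags = {
        "mw_ok": mw <= 500,
        "logp_ok": logp <= 5,
        "hbd_ok": hbd <= 5,
        "hba_ok": hba <= 10,
    }
    flags["pass"] = all(flags.values())
    return flags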

# ------------------------------------------------------------------
# Task 3: Screen candidates against target profile
# ------------------------------------------------------------------
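
# Sketch (illustrative): screen_candidates (below) compares molecules with the
# Tanimoto coefficient on Morgan fingerprint bit vectors, using RDKit's
# DataStructs.TanimotoSimilarity. For two sets of "on" bits A and B, the
# coefficient is |A & B| / |A | B|. The hypothetical helper below computes it
# on Python sets of bit indices. Treating two empty sets as 0.0 is a choice
# made for this sketch only.
def _tanimoto(a: set[int], b: set[int]) -> float:
    """Tanimoto (Jaccard) similarity of two sets of set-bit indices."""
    union = len(a | b)
    return len(a & b) / union if union else 0.0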

@tool
@env.task(report=True)
async def screen_candidates(
    properties_json: str,
    target_profile: str = "",
) -> str:
    """Screen molecules against a target drug profile and rank candidates.

    Scores each molecule on how well it matches the target profile, computes
    pairwise Tanimoto similarity, and produces a ranked list.

    Args:
        properties_json: JSON from compute_properties.
        target_profile: JSON string with desired property ranges
            (e.g. {"mw": [150, 500], "logp": [-0.5, 5.0]}).

    Returns:
        JSON string with ranked_molecules, similarity_matrix, similarity_labels,
        funnel, and target_profile. Pass the full return value verbatim to
        generate_report along with molecule_dir and properties_json.
    """
    from rdkit import Chem, DataStructs
    from rdkit.Chem import AllChem

    await flyte.report.replace.aio(
        _wrap_report("<h2>Screening Candidates...</h2>"
                      "<p>Evaluating molecules against the target drug profile.</p>"),
        do_flush=True,
    )

    props = json.loads(properties_json)
    molecules = props["molecules"]

    # Default target profile
    if target_profile.strip():
        profile = json.loads(target_profile)
    else:
        profile = {
            "mw": [150, 500],
            "logp": [-0.5, 5.0],
            "hbd": [0, 5],
            "hba": [0, 10],
            "tpsa": [20, 140],
        }

    # --- Screening ---
    funnel_total = len(molecules)
    pass_mw = 0
    pass_logp = 0
    pass_lipinski = 0
    final_candidates = 0

    scored = []
    for m in molecules:
        score = 0
        max_score = 0
        criteria = {}

        # Check each profile criterion
        checks = [
            ("mw", m["mw"]),
            ("logp", m["logp"]),
            ("hbd", m["hbd"]),
            ("hba", m["hba"]),
            ("tpsa", m["tpsa"]),
        ]

        for key, val in checks:
            if key in profile:
                lo, hi = profile[key]
                max_score += 1.5  # 1 for being in range + up to 0.5 centrality bonus
                in_range = lo <= val <= hi
                criteria[key] = in_range
                if in_range:
                    score += 1
                    # Bonus: closer to midpoint = higher score
                    mid = (lo + hi) / 2
                    rng = (hi - lo) / 2
                    dist = abs(val - mid) / rng if rng else 0
                    score += max(0, 0.5 * (1 - dist))

        # QED bonus
        score += m["qed"] * 2
        max_score += 2

        # Lipinski bonus
        if m["lipinski_pass"]:
            score += 1
        max_score += 1

        normalized_score = score / max_score if max_score else 0

        # Funnel tracking — cascading filter (each stage requires passing the previous)
        mw_ok = criteria.get("mw", True)
        logp_ok = criteria.get("logp", True)
        if mw_ok:
            pass_mw += 1
            if logp_ok:
                pass_logp += 1
                if m["lipinski_pass"]:
                    pass_lipinski += 1
                    if all(criteria.values()):
                        final_candidates += 1

        scored.append({
            **m,
            "screening_score": round(normalized_score, 4),
            "criteria_met": criteria,
            "all_criteria_met": all(criteria.values()),
        })

    # Sort by score descending
    scored.sort(key=lambda m: m["screening_score"], reverse=True)

    # --- Tanimoto similarity matrix ---
    fps = []
    valid_names = []
    for m in scored:
        mol = Chem.MolFromSmiles(m["smiles"])
        if mol:
            fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048)
            fps.append(fp)
            valid_names.append(m["name"])

    similarity_matrix = []
    for i in range(len(fps)):
        row = []
        for j in range(len(fps)):
            sim = DataStructs.TanimotoSimilarity(fps[i], fps[j])
            row.append(round(sim, 3))
        similarity_matrix.append(row)

    # ---- Build report ----
    html_parts = []
    html_parts.append("<h2>Candidate Screening Results</h2>")

    # Stat grid
    html_parts.append('<div class="stat-grid">')
    for val, label in [
        (str(funnel_total), "Total Screened"),
        (str(sum(1 for m in scored if m["lipinski_pass"])), "Lipinski Passes"),
        (str(final_candidates), "All Criteria Met"),
        (f"{scored[0]['screening_score']:.3f}" if scored else "N/A", "Top Score"),
    ]:
        html_parts.append(
            f'<div class="stat"><div class="value">{val}</div>'
            f'<div class="label">{label}</div></div>'
        )
    html_parts.append("</div>")

    # Screening funnel
    funnel_stages = [
        {"label": "Total Molecules", "count": funnel_total, "total": funnel_total},
        {"label": "Pass MW Filter", "count": pass_mw, "total": funnel_total},
        {"label": "Pass LogP Filter", "count": pass_logp, "total": funnel_total},
        {"label": "Lipinski Compliant", "count": pass_lipinski, "total": funnel_total},
        {"label": "All Criteria Met", "count": final_candidates, "total": funnel_total},
    ]
    funnel_svg = _make_funnel(
        funnel_stages,
        title="Screening Funnel",
        width=600,
        height=380,
    )
    html_parts.append("<h3>Screening Funnel</h3>")
    html_parts.append(f'<div class="chart-container" style="text-align:center;">{funnel_svg}</div>')

    # Ranked candidates table
    html_parts.append("<h3>Ranked Candidates</h3>")
    html_parts.append(
        "<table><tr><th>Rank</th><th>Molecule</th><th>Score</th>"
        "<th>MW</th><th>LogP</th><th>QED</th><th>Lipinski</th><th>All Criteria</th></tr>"
    )
    for rank, m in enumerate(scored, 1):
        lip_badge = ('<span class="badge badge-success">Pass</span>'
                     if m["lipinski_pass"]
                     else '<span class="badge badge-danger">Fail</span>')
        crit_badge = ('<span class="badge badge-success">Pass</span>'
                      if m["all_criteria_met"]
                      else '<span class="badge badge-danger">Fail</span>')
        # Highlight top 3
        row_style = ' style="background:#ecfeff;font-weight:600;"' if rank <= 3 else ""
        html_parts.append(
            f"<tr{row_style}><td>{rank}</td><td>{m['name']}</td>"
            f"<td>{m['screening_score']:.3f}</td>"
            f"<td>{m['mw']:.1f}</td><td>{m['logp']:.2f}</td>"
            f"<td>{m['qed']:.3f}</td><td>{lip_badge}</td><td>{crit_badge}</td></tr>"
        )
    html_parts.append("</table>")

    # Top 5 candidate cards with structures
    html_parts.append("<h3>Top 5 Candidates</h3>")
    html_parts.append('<div class="molecule-grid">')
    for m in scored[:5]:
        mol = Chem.MolFromSmiles(m["smiles"])
        img_uri = _mol_to_data_uri(mol, size=(250, 250)) if mol else ""
        badge_class = "badge-success" if m["all_criteria_met"] else "badge-info"
        badge_text = "All Criteria Met" if m["all_criteria_met"] else "Partial Match"
        html_parts.append(
            f'<div class="molecule-card" style="text-align:center;">'
            f'<img src="{img_uri}" style="width:140px;height:140px;object-fit:contain;"/>'
            f'<div style="font-weight:700;margin-top:6px;color:#0e4f6e;font-size:1.05em;">{m["name"]}</div>'
            f'<div style="font-size:0.85em;color:#155e75;margin:4px 0;">Score: {m["screening_score"]:.3f}</div>'
            f'<div style="font-size:0.8em;color:#6c757d;">MW: {m["mw"]:.1f} | LogP: {m["logp"]:.2f} | QED: {m["qed"]:.3f}</div>'
            f'<div style="margin-top:4px;"><span class="badge {badge_class}">{badge_text}</span></div>'
            f'</div>'
        )
    html_parts.append("</div>")

    # Tanimoto similarity heatmap
    if similarity_matrix:
        sim_heatmap = _make_heatmap(
            similarity_matrix, valid_names, valid_names,
            title="Pairwise Tanimoto Similarity (Morgan Fingerprints)",
            color_scale="cyan",
            width=700,
            height=max(500, len(valid_names) * 32 + 100),
        )
        html_parts.append("<h3>Chemical Similarity</h3>")
        html_parts.append(f'<div class="chart-container">{sim_heatmap}</div>')

    await flyte.report.replace.aio(
        _wrap_report("\n".join(html_parts)),
        do_flush=True,
    )

    output = {
        "ranked_molecules": scored,
        "similarity_matrix": similarity_matrix,
        "similarity_labels": valid_names,
        "funnel": funnel_stages,
        "target_profile": profile,
    }
    return json.dumps(output)
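
# Sketch (illustrative): in screen_candidates, each profile criterion earns
# 1 point when the value lies in [lo, hi], plus a centrality bonus of up to
# 0.5 that falls linearly from the midpoint to either bound. The hypothetical
# helper below isolates that per-criterion rule, which yields 0.0 to 1.5.
# With the default MW range [150, 500], a 325 Da molecule scores 1.5 and a
# 500 Da molecule scores 1.0.
def _range_score(val: float, lo: float, hi: float) -> float:
    """Score one property against a [lo, hi] target range."""
    if not lo <= val <= hi:
        return 0.0
    mid = (lo + hi) / 2
    half_width = (hi - lo) / 2
    dist = abs(val - mid) / half_width if half_width else 0.0
    return 1.0 + max(0.0, 0.5 * (1 - dist))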

def _parse_screening_json(screening_json: str) -> dict:
    """Parse screening JSON from screen_candidates, with safe defaults.

    The agent must pass the exact tool return value. Only the optional
    similarity fields may be missing; they default to empty lists.
    """
    screening = json.loads(screening_json)
    if "ranked_molecules" not in screening:
        raise ValueError(
            "screening_json must be the exact JSON string returned by "
            "screen_candidates (missing 'ranked_molecules'). Do not construct, "
            "truncate, or summarize tool output."
        )
    screening.setdefault("similarity_matrix", [])
    screening.setdefault("similarity_labels", [])
    return screening

# ------------------------------------------------------------------
# Task 4: Generate final comprehensive report
# ------------------------------------------------------------------
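
# Sketch (illustrative): the property box plots in generate_report (below)
# summarize each property with min, Q1, median, Q3, and max. generate_report
# picks the quartiles by index (vals[n // 4] and vals[3 * n // 4]), a simple
# approximation that can differ from interpolated quartiles. The hypothetical
# helper below uses the standard library's statistics.quantiles instead. Its
# default "exclusive" method is only one of several quartile conventions.
def _five_number_summary(values: list[float]) -> tuple[float, float, float, float, float]:
    """Return (min, Q1, median, Q3, max) for at least two values."""
    import statistics

    vals = sorted(values)
    q1, median, q3 = statistics.quantiles(vals, n=4)
    return vals[0], q1, median, q3, vals[-1]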

@tool
@env.task(report=True)
async def generate_report(
    molecule_dir: flyte.io.Dir,
    properties_json: str,
    screening_json: str,
) -> str:
    """Generate a comprehensive drug screening report.

    Produces an executive summary, top candidate spotlight cards, property
    distributions, chemical diversity analysis, and final recommendation.

    Args:
        molecule_dir: Directory from load_molecules.
        properties_json: JSON from compute_properties.
        screening_json: Exact verbatim JSON string returned by screen_candidates
            (must include ranked_molecules, similarity_matrix, similarity_labels).
            Do not construct or summarize this payload yourself.

    Returns:
        JSON summary with total_screened, lipinski_passes, all_criteria_met,
        top_candidate, top_score, and top_3 ranked molecules.
    """
    from rdkit import Chem

    await flyte.report.replace.aio(
        _wrap_report("<h2>Generating Final Report...</h2>"),
        do_flush=True,
    )

    props = json.loads(properties_json)
    screening = _parse_screening_json(screening_json)
    ranked = screening["ranked_molecules"]
    sim_matrix = screening["similarity_matrix"]
    sim_labels = screening["similarity_labels"]

    total = props["total"]
    lipinski_pass = props["lipinski_pass_count"]
    all_criteria = sum(1 for m in ranked if m["all_criteria_met"])
    top = ranked[0] if ranked else None

    html_parts = []

    # --- Executive Summary ---
    html_parts.append("<h2>Drug Molecule Screening Report</h2>")
    top_name = top["name"] if top else "N/A"
    top_score = f'{top["screening_score"]:.3f}' if top else "N/A"
    html_parts.append(
        f'<div class="card">'
        f'<h3 style="margin-top:0;color:#0e4f6e;">Executive Summary</h3>'
        f'<p style="font-size:1.05em;">'
        f'<strong>{total}</strong> molecules were screened against the target drug profile. '
        f'<strong>{lipinski_pass}</strong> passed Lipinski\'s Rule of Five, and '
        f'<strong>{all_criteria}</strong> met all screening criteria. '
        f'The top candidate is <strong style="color:#0891b2;">{top_name}</strong> '
        f'with a screening score of <strong>{top_score}</strong>.</p>'
        f'</div>'
    )

    # Stat grid
    html_parts.append('<div class="stat-grid">')
    for val, label in [
        (str(total), "Molecules Screened"),
        (str(lipinski_pass), "Lipinski Passes"),
        (str(all_criteria), "All Criteria Met"),
        (top_score, "Top Score"),
        (f'{props["avg_mw"]:.0f} Da', "Avg. Molecular Weight"),
        (f'{props["avg_logp"]:.2f}', "Avg. LogP"),
    ]:
        html_parts.append(
            f'<div class="stat"><div class="value">{val}</div>'
            f'<div class="label">{label}</div></div>'
        )
    html_parts.append("</div>")

    # --- Top 3 Candidate Spotlights ---
    html_parts.append("<h2>Top Candidate Spotlights</h2>")

    for rank, m in enumerate(ranked[:3], 1):
        mol = Chem.MolFromSmiles(m["smiles"])
        img_uri = _mol_to_data_uri(mol, size=(300, 300)) if mol else ""

        medal = ["gold", "silver", "#cd7f32"][rank - 1]
        medal_emoji = ["1st", "2nd", "3rd"][rank - 1]

        lip_badges = ""
        for rule, key in [("MW", "mw_ok"), ("LogP", "logp_ok"),
                          ("HBD", "hbd_ok"), ("HBA", "hba_ok")]:
            ok = m["lipinski"].get(key, False)
            cls = "badge-success" if ok else "badge-danger"
            lip_badges += f'<span class="badge {cls}" style="margin:2px;">{rule}</span> '

        html_parts.append(
            f'<div class="molecule-card" style="display:flex;gap:20px;align-items:flex-start;flex-wrap:wrap;">'
            f'<div style="text-align:center;min-width:180px;">'
            f'<div style="font-size:1.6em;font-weight:800;color:{medal};">{medal_emoji}</div>'
            f'<img src="{img_uri}" style="width:200px;height:200px;object-fit:contain;border-radius:8px;'
            f'border:2px solid #a5f3fc;"/>'
            f'<div style="font-weight:700;font-size:1.1em;color:#0e4f6e;margin-top:8px;">{m["name"]}</div>'
            f'</div>'
            f'<div style="flex:1;min-width:280px;">'
            f'<table style="margin:0;">'
            f'<tr><td><strong>SMILES</strong></td><td style="font-family:monospace;font-size:0.8em;word-break:break-all;">{m["smiles"]}</td></tr>'
            f'<tr><td><strong>Screening Score</strong></td><td style="font-weight:700;color:#0891b2;font-size:1.1em;">{m["screening_score"]:.3f}</td></tr>'
            f'<tr><td><strong>Molecular Weight</strong></td><td>{m["mw"]:.1f} Da</td></tr>'
            f'<tr><td><strong>LogP</strong></td><td>{m["logp"]:.2f}</td></tr>'
            f'<tr><td><strong>H-Bond Donors</strong></td><td>{m["hbd"]}</td></tr>'
            f'<tr><td><strong>H-Bond Acceptors</strong></td><td>{m["hba"]}</td></tr>'
            f'<tr><td><strong>TPSA</strong></td><td>{m["tpsa"]:.1f} &Aring;&sup2;</td></tr>'
            f'<tr><td><strong>Rotatable Bonds</strong></td><td>{m["rotatable_bonds"]}</td></tr>'
            f'<tr><td><strong>QED</strong></td><td>{m["qed"]:.4f}</td></tr>'
            f'<tr><td><strong>Lipinski Compliance</strong></td><td>{lip_badges}</td></tr>'
            f'</table>'
            f'</div>'
            f'</div>'
        )

    # --- Property Distributions (horizontal box plots: min, Q1, median, Q3, max) ---
    html_parts.append("<h2>Property Distributions</h2>")

    prop_keys = [("mw", "Molecular Weight (Da)"), ("logp", "LogP"),
                 ("tpsa", "TPSA (&Aring;&sup2;)"), ("qed", "QED Score")]
    for key, label in prop_keys:
        vals = sorted([m[key] for m in ranked])
        n = len(vals)
        if n == 0:
            continue
        v_min = vals[0]
        v_max = vals[-1]
        median = vals[n // 2] if n % 2 == 1 else (vals[n // 2 - 1] + vals[n // 2]) / 2
        q1 = vals[n // 4] if n >= 4 else v_min
        q3 = vals[3 * n // 4] if n >= 4 else v_max

        # Simple horizontal box-plot as SVG
        box_w = 500
        box_h = 50
        margin_l = 10
        v_range = v_max - v_min or 1

        def sx(v):
            return margin_l + ((v - v_min) / v_range) * (box_w - 2 * margin_l)

        box_svg = (
            f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {box_w} {box_h}" '
            f'style="width:100%;max-width:{box_w}px;height:auto;">'
            f'<rect width="{box_w}" height="{box_h}" fill="#fff" rx="4"/>'
            # Whisker line
            f'<line x1="{sx(v_min):.1f}" y1="25" x2="{sx(v_max):.1f}" y2="25" '
            f'stroke="#94a3b8" stroke-width="1.5"/>'
            # Min whisker
            f'<line x1="{sx(v_min):.1f}" y1="18" x2="{sx(v_min):.1f}" y2="32" '
            f'stroke="#94a3b8" stroke-width="1.5"/>'
            # Max whisker
            f'<line x1="{sx(v_max):.1f}" y1="18" x2="{sx(v_max):.1f}" y2="32" '
            f'stroke="#94a3b8" stroke-width="1.5"/>'
            # IQR box
            f'<rect x="{sx(q1):.1f}" y="14" width="{sx(q3) - sx(q1):.1f}" height="22" '
            f'fill="#a5f3fc" stroke="#0891b2" stroke-width="1.5" rx="3"/>'
            # Median line
            f'<line x1="{sx(median):.1f}" y1="12" x2="{sx(median):.1f}" y2="38" '
            f'stroke="#0e4f6e" stroke-width="2"/>'
            # Labels
            f'<text x="{sx(v_min):.1f}" y="46" text-anchor="middle" font-size="9" fill="#6c757d">{v_min:.1f}</text>'
            f'<text x="{sx(median):.1f}" y="10" text-anchor="middle" font-size="9" fill="#0e4f6e" font-weight="600">{median:.1f}</text>'
            f'<text x="{sx(v_max):.1f}" y="46" text-anchor="middle" font-size="9" fill="#6c757d">{v_max:.1f}</text>'
            f'</svg>'
        )
        html_parts.append(
            f'<div style="margin:8px 0;"><strong style="color:#155e75;">{label}</strong>'
            f'<div class="chart-container" style="padding:8px;">{box_svg}</div></div>'
        )

    # --- Chemical Diversity ---
    html_parts.append("<h2>Chemical Diversity Analysis</h2>")

    if sim_matrix and len(sim_matrix) > 1:
        # Compute average pairwise similarity (off-diagonal)
        n_mols = len(sim_matrix)
        off_diag = []
        for i in range(n_mols):
            for j in range(i + 1, n_mols):
                off_diag.append(sim_matrix[i][j])

        avg_sim = sum(off_diag) / len(off_diag) if off_diag else 0
        max_sim = max(off_diag) if off_diag else 0
        min_sim = min(off_diag) if off_diag else 0

        # Find most similar pair
        best_i, best_j = 0, 1
        best_val = 0
        for i in range(n_mols):
            for j in range(i + 1, n_mols):
                if sim_matrix[i][j] > best_val:
                    best_val = sim_matrix[i][j]
                    best_i, best_j = i, j

        html_parts.append('<div class="stat-grid">')
        html_parts.append(
            f'<div class="stat"><div class="value">{avg_sim:.3f}</div>'
            f'<div class="label">Avg. Pairwise Similarity</div></div>'
        )
        html_parts.append(
            f'<div class="stat"><div class="value">{min_sim:.3f}</div>'
            f'<div class="label">Min Similarity</div></div>'
        )
        html_parts.append(
            f'<div class="stat"><div class="value">{max_sim:.3f}</div>'
            f'<div class="label">Max Similarity</div></div>'
        )
        html_parts.append("</div>")

        diversity_text = "highly diverse" if avg_sim < 0.3 else "moderately diverse" if avg_sim < 0.5 else "relatively similar"
        html_parts.append(
            f'<div class="note">'
            f'The library is <strong>{diversity_text}</strong> (avg. Tanimoto = {avg_sim:.3f}). '
            f'The most similar pair is <strong>{sim_labels[best_i]}</strong> and '
            f'<strong>{sim_labels[best_j]}</strong> (similarity = {best_val:.3f}).</div>'
        )

    # --- Recommendation ---
    html_parts.append("<h2>Recommendation</h2>")
    if top:
        html_parts.append(
            f'<div class="card">'
            f'<h3 style="margin-top:0;color:#0891b2;">Top Candidate: {top["name"]}</h3>'
            f'<p>Based on the virtual screening analysis, <strong>{top["name"]}</strong> '
            f'achieved the highest composite screening score of <strong>{top["screening_score"]:.3f}</strong>. '
        )

        reasons = []
        if top["lipinski_pass"]:
            reasons.append("full Lipinski Rule of Five compliance")
        if top["qed"] > 0.5:
            reasons.append(f"high drug-likeness (QED = {top['qed']:.3f})")
        if top.get("all_criteria_met"):
            reasons.append("all target profile criteria met")
        if top["mw"] <= 500:
            reasons.append(f"favorable molecular weight ({top['mw']:.1f} Da)")

        if reasons:
            html_parts.append(
                f'This candidate stands out due to: {", ".join(reasons)}.</p>'
            )
        else:
            html_parts.append("</p>")

        # Runner-up mentions
        if len(ranked) >= 2:
            html_parts.append(
                f'<p style="font-size:0.9em;color:#6c757d;">Runner-up candidates: '
            )
            runners = []
            for m in ranked[1:4]:
                runners.append(f'{m["name"]} (score: {m["screening_score"]:.3f})')
            html_parts.append(", ".join(runners) + ".</p>")

        html_parts.append("</div>")

    # Final note
    html_parts.append(
        '<div class="note">'
        "This is a virtual screening analysis. All candidates should undergo "
        "further computational validation (molecular dynamics, docking) and "
        "experimental testing before advancing to clinical trials.</div>"
    )

    await flyte.report.replace.aio(
        _wrap_report("\n".join(html_parts)),
        do_flush=True,
    )

    # JSON summary
    summary = {
        "total_screened": total,
        "lipinski_passes": lipinski_pass,
        "all_criteria_met": all_criteria,
        "top_candidate": top["name"] if top else None,
        "top_score": top["screening_score"] if top else None,
        "top_3": [
            {"name": m["name"], "score": m["screening_score"]}
            for m in ranked[:3]
        ],
    }
    return json.dumps(summary)

# ------------------------------------------------------------------
# Agent
# ------------------------------------------------------------------

# {{docs-fragment agent}}
SCREENING_AGENT_INSTRUCTIONS = """\
You are a medicinal chemistry screening strategist. You orchestrate a virtual \
screening pipeline using durable Flyte tools. You NEVER invent molecular \
properties — only RDKit tools compute them.

Workflow:
1. If target_profile is not provided in the user message, derive a JSON \
target_profile from the therapeutic brief. Valid keys: mw, logp, hbd, hba, tpsa \
(each [min, max]). Ground choices in oral bioavailability / kinase / CNS rules \
as appropriate to the brief.
2. First pass (always): load_molecules → compute_properties → \
screen_candidates → generate_report. Pass tool outputs between steps exactly \
(molecule_dir from load_molecules into compute_properties and generate_report; \
properties_json from compute_properties into screen_candidates and \
generate_report; screening_json must be the complete, unmodified string \
returned by screen_candidates — never rebuild or summarize JSON yourself).
3. Read the JSON summary returned by generate_report. Reflect:
   - If all_criteria_met == 0: relax exactly ONE profile bound by ~10–20% \
and re-run screen_candidates then generate_report only, reusing the same \
molecule_dir and properties_json from the first pass.
   - If all molecules pass but diversity is a stated goal: note high similarity \
in your summary; do not re-run unless brief asks for stricter filters.
   - Maximum ONE rescreen iteration.
4. Finish with plain text: top candidate, rationale tied to computed metrics \
from the tool JSON, funnel interpretation, and suggested next steps (docking, \
ADMET lab tests).

If the user supplies an explicit target_profile JSON, use it as-is.

Do NOT ask the user for SMILES or molecule lists when molecules_json is empty — \
the default library is loaded automatically.
"""

screening_agent = Agent(
    name="drug-screening-agent",
    instructions=SCREENING_AGENT_INSTRUCTIONS,
    model=MODEL,
    tools=[
        load_molecules,
        compute_properties,
        screen_candidates,
        generate_report,
    ],
    max_turns=12,
)
# {{/docs-fragment agent}}

# ------------------------------------------------------------------
# Pipeline
# ------------------------------------------------------------------

# {{docs-fragment pipeline}}
@env.task(report=True)
async def pipeline(
    brief: str = "Screen the default drug library for orally bioavailable small molecules.",
    molecules_json: str = "",
    target_profile: str = "",
) -> str:
    """Agentic virtual drug molecule screening pipeline.

    A medicinal-chemistry agent interprets the screening brief, derives or
    applies a target profile, orchestrates the RDKit screening stages, and
    optionally re-screens when funnel results are too narrow.

    Args:
        brief: Natural-language therapeutic goal (e.g. oral kinase inhibitors,
            CNS-penetrant small molecules).
        molecules_json: JSON mapping molecule names to SMILES strings.
            Defaults to a curated library of ~15 well-known drugs.
        target_profile: Optional JSON with desired property ranges that
            overrides agent-derived criteria
            (e.g. {"mw": [150, 500], "logp": [-0.5, 5]}).

    Returns:
        Agent summary with screening rationale and key results.
    """
    prompt_parts = [
        f"Screening brief: {brief}",
        'Use molecules_json="" for the built-in default library unless provided below.',
        "Compose the four stage tools in order: load_molecules → compute_properties "
        "→ screen_candidates → generate_report. Pass each tool's full return value "
        "verbatim to the next step (especially screening_json). Re-run "
        "screen_candidates and generate_report at most once if the funnel is too narrow.",
    ]
    if molecules_json.strip():
        prompt_parts.append(f"molecules_json: {molecules_json}")
    if target_profile.strip():
        prompt_parts.append(f"Use this target_profile exactly: {target_profile}")

    result = await screening_agent.run.aio("\n".join(prompt_parts))
    return result.summary or result.error or ""

# {{/docs-fragment pipeline}}

# ------------------------------------------------------------------
# Rescreen demo — tight profile + explicit rescreen instructions
# ------------------------------------------------------------------

# Initial profile is deliberately strict (narrow MW + low LogP cap) so
# all_criteria_met is typically 0 on the default library; the brief then
# forces a single rescreen with a widened LogP window.
RESCREEN_DEMO_TARGET_PROFILE = (
    '{"mw": [150, 200], "logp": [-0.5, 1.0], "hbd": [0, 1], '
    '"hba": [0, 3], "tpsa": [20, 45]}'
)
RESCREEN_DEMO_TARGET_PROFILE_RESCREEN = (
    '{"mw": [150, 200], "logp": [-0.5, 3.5], "hbd": [0, 1], '
    '"hba": [0, 3], "tpsa": [20, 45]}'
)
RESCREEN_DEMO_BRIEF = f"""\
Two-round agentic screening demo on the default library.

**Round 1 (strict profile):** load_molecules → compute_properties → \
screen_candidates → generate_report using the initial target_profile exactly.

**Round 2 (required — do not skip):** call screen_candidates then generate_report \
again, reusing the same molecule_dir and properties_json from round 1, with this \
relaxed target_profile (wider LogP window only): \
{RESCREEN_DEMO_TARGET_PROFILE_RESCREEN}

Pass every tool return value verbatim to the next step. After both rounds, \
summarize how the funnel and top candidates changed between round 1 and round 2."""

# {{docs-fragment rescreen_demo}}
@env.task(report=True)
async def rescreen_demo() -> str:
    """Example run with a two-round execution graph (rescreen).

    Round 1 uses a strict CNS-like profile; round 2 always re-runs
    screen_candidates and generate_report with a widened LogP window,
    reusing cached molecule_dir and properties_json.
    """
    return await pipeline(
        brief=RESCREEN_DEMO_BRIEF,
        target_profile=RESCREEN_DEMO_TARGET_PROFILE,
    )

# {{/docs-fragment rescreen_demo}}

# {{docs-fragment main}}
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(pipeline)
    print(run.url)
    run.wait()
# {{/docs-fragment main}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/drug_molecule_screening/drug_molecule_screening.py*
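The Chemical Diversity Analysis section of `generate_report` averages the upper triangle of a similarity matrix (the report labels it Tanimoto similarity) and maps that average to a label: below 0.3 is "highly diverse", below 0.5 is "moderately diverse", and anything higher is "relatively similar". This excerpt does not show how `sim_matrix` is built. The minimal sketch below assumes that fingerprints are plain Python sets of "on" bit indices, so the arithmetic can be followed without RDKit. The toy fingerprints are hypothetical stand-ins, not real RDKit output.

```python
from itertools import combinations


def tanimoto(a: set[int], b: set[int]) -> float:
    """Tanimoto similarity of two fingerprints given as sets of on-bits: |A & B| / |A | B|."""
    union = len(a | b)
    return len(a & b) / union if union else 0.0


def similarity_matrix(fps: list[set[int]]) -> list[list[float]]:
    """Full symmetric matrix; the diagonal is 1.0 for any non-empty fingerprint."""
    return [[tanimoto(a, b) for b in fps] for a in fps]


def diversity_summary(sim: list[list[float]]) -> tuple[float, str]:
    """Average upper-triangle similarity plus the label thresholds used in the report."""
    n = len(sim)
    off_diag = [sim[i][j] for i, j in combinations(range(n), 2)]
    avg = sum(off_diag) / len(off_diag) if off_diag else 0.0
    if avg < 0.3:
        label = "highly diverse"
    elif avg < 0.5:
        label = "moderately diverse"
    else:
        label = "relatively similar"
    return avg, label


# Hypothetical toy fingerprints (bit indices), not real RDKit output.
fps = [{1, 2, 3, 4}, {3, 4, 5, 6}, {7, 8, 9}]
avg, label = diversity_summary(similarity_matrix(fps))
print(f"{avg:.3f} {label}")  # prints "0.111 highly diverse"
```

Only the first pair shares bits (2 of 6, so 1/3). The two pairs involving the third fingerprint score 0, which gives an average of 1/9.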

From the [example directory](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/drug_molecule_screening):

```
cd v2/tutorials/drug_molecule_screening
uv run --script drug_molecule_screening.py
```

Pass a natural-language brief (the agent derives the target profile):

```
flyte run drug_molecule_screening.py pipeline \
  --brief "Find oral kinase inhibitor candidates under 400 Da with moderate LogP"
```

Or pass an explicit target profile to skip agent-derived criteria:

```
flyte run drug_molecule_screening.py pipeline \
  --target_profile '{"mw": [100, 400], "logp": [-0.5, 4.0]}'
```
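A target profile maps property names (`mw`, `logp`, `hbd`, `hba`, `tpsa`) to `[min, max]` ranges. This excerpt does not include the implementation of `screen_candidates`, so the sketch below rests on two assumptions. First, each profiled property must fall inside its range, with both ends inclusive. Second, properties the profile omits are not checked. `meets_profile` and the property values are hypothetical and are only meant to show how a profile like the one above is read.

```python
import json


def meets_profile(props: dict[str, float], profile_json: str) -> dict[str, bool]:
    """Check each profiled property against its [min, max] range (assumed inclusive).

    Returns a per-property pass/fail map; properties absent from the profile are skipped.
    """
    profile = json.loads(profile_json)
    return {key: lo <= props[key] <= hi for key, (lo, hi) in profile.items()}


# Same profile string as the CLI example above.
profile = '{"mw": [100, 400], "logp": [-0.5, 4.0]}'
# Hypothetical computed properties for one molecule (illustrative numbers only).
candidate = {"mw": 420.5, "logp": 3.1, "hbd": 2, "hba": 6, "tpsa": 88.0}

checks = meets_profile(candidate, profile)
print(checks, all(checks.values()))  # prints {'mw': False, 'logp': True} False
```

Returning per-property results instead of a single boolean shows which bound a molecule misses. That is the kind of information needed to decide which single bound to relax on a rescreen.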

### Two-round rescreen demo (complex execution graph)

The `rescreen_demo` task always runs two screening rounds. The first is a strict pass (`load_molecules` → `compute_properties` → `screen_candidates` → `generate_report`). The second is a `screen_candidates` → `generate_report` pass that widens the LogP window (upper bound raised from 1.0 to 3.5) and reuses the same `molecule_dir` and `properties_json`. The brief marks round 2 as required instead of leaving the rescreen to the agent's reflection step, so the Flyte UI shows six stage tasks instead of four.

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.5.4",
#    "litellm",
#    "rdkit",
#    "numpy",
#    "scikit-learn",
#    "pillow",
# ]
# main = "pipeline"
# params = ""
# ///
"""Virtual drug molecule screening — compute properties, apply Lipinski filters, rank candidates."""

import base64
import io
import json
import logging
import math
import os
import tempfile

import flyte
import flyte.io
import flyte.report
from flyte.ai.agents import Agent, tool

MODEL = os.getenv("DRUG_SCREENING_MODEL", "claude-haiku-4-5")

# {{docs-fragment env}}
main_img = flyte.Image.from_uv_script(__file__, name="drug-molecule-screening", pre=True).with_apt_packages(
    "libxrender1", "libxext6", "libexpat1",
)

env = flyte.TaskEnvironment(
    name="drug-molecule-screening",
    image=main_img,
    resources=flyte.Resources(cpu=2, memory="6Gi"),
    secrets=[
        flyte.Secret(key="internal-anthropic-api-key", as_env_var="ANTHROPIC_API_KEY"),
    ],
)
# {{/docs-fragment env}}

logging.basicConfig(level=logging.WARNING, format="%(message)s", force=True)
log = logging.getLogger(__name__)
log.setLevel(logging.INFO)

# ------------------------------------------------------------------
# Default molecule library — real SMILES for well-known drugs
# ------------------------------------------------------------------

DEFAULT_MOLECULES = {
    "Aspirin": "CC(=O)OC1=CC=CC=C1C(=O)O",
    "Ibuprofen": "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O",
    "Caffeine": "CN1C=NC2=C1C(=O)N(C(=O)N2C)C",
    "Penicillin G": "CC1(C(N2C(S1)C(C2=O)NC(=O)CC3=CC=CC=C3)C(=O)O)C",
    "Metformin": "CN(C)C(=N)NC(=N)N",
    "Paracetamol": "CC(=O)NC1=CC=C(C=C1)O",
    "Diazepam": "ClC1=CC2=C(C=C1)N(C(=O)CN=C2C3=CC=CC=C3)C",
    "Omeprazole": "CC1=CN=C(C(=C1OC)C)CS(=O)C2=NC3=CC=CC=C3N2",
    "Atorvastatin": "CC(C)C1=C(C(=C(N1CCC(CC(CC(=O)O)O)O)C2=CC=C(C=C2)F)C3=CC=CC=C3)C(=O)NC4=CC=CC=C4",
    "Methotrexate": "CN(CC1=CN=C2N=C(N=C(N)C2=N1)N)C3=CC=C(C=C3)C(=O)NC(CCC(=O)O)C(=O)O",
    "Doxorubicin": "CC1C(C(CC(O1)OC2CC(CC3=C2C(=C4C(=C3O)C(=O)C5=C(C4=O)C(=CC=C5)OC)O)(C(=O)CO)O)N)O",
    "Tamoxifen": "CCC(=C(C1=CC=CC=C1)C2=CC=C(C=C2)OCCN(C)C)C3=CC=CC=C3",
    "Lopinavir": "CC1=C(C(=CC=C1)C)OCC(=O)NC(CC2=CC=CC=C2)C(CC(CC3=CC=CC=C3)NC(=O)C(C(C)C)N4CCCNC4=O)O",
    "Remdesivir": "CCC(CC)COC(=O)C(C)NP(=O)(OCC1C(C(C(O1)C2=CC=C3N2N=CN=C3N)O)O)OC4=CC=CC=C4",
    "Erlotinib": "COCCOC1=CC2=C(C=C1OCCOC)C(=NC=N2)NC3=CC=CC(=C3)C#C",
}

# ------------------------------------------------------------------
# Report styling — pharma blue/cyan theme
# ------------------------------------------------------------------

REPORT_CSS = """
<style>
  .report { font-family: system-ui, -apple-system, sans-serif; max-width: 960px; margin: 0 auto; color: #1a1a2e; }
  .report h2 { color: #0e4f6e; border-bottom: 2px solid #0891b2; padding-bottom: 8px; margin-top: 24px; }
  .report h3 { color: #155e75; margin-top: 20px; }
  .report .card { background: #ecfeff; border: 1px solid #a5f3fc; border-radius: 8px; padding: 16px; margin: 12px 0; }
  .report .stat-grid { display: grid; grid-template-columns: repeat(auto-fit, minmax(160px, 1fr)); gap: 12px; margin: 12px 0; }
  .report .stat { background: #fff; border: 1px solid #cffafe; border-radius: 6px; padding: 12px; text-align: center; }
  .report .stat .value { font-size: 1.5em; font-weight: 700; color: #0e4f6e; }
  .report .stat .label { font-size: 0.85em; color: #6c757d; margin-top: 4px; }
  .report table { border-collapse: collapse; width: 100%; margin: 12px 0; }
  .report th { background: #0e4f6e; color: #fff; padding: 10px 14px; text-align: left; font-weight: 600; }
  .report td { padding: 8px 14px; border-bottom: 1px solid #cffafe; }
  .report tr:nth-child(even) { background: #ecfeff; }
  .report .badge { display: inline-block; padding: 2px 8px; border-radius: 12px; font-size: 0.8em; font-weight: 600; }
  .report .badge-success { background: #d1fae5; color: #065f46; }
  .report .badge-danger { background: #fee2e2; color: #991b1b; }
  .report .badge-info { background: #cffafe; color: #155e75; }
  .report .chart-container { background: #fff; border: 1px solid #cffafe; border-radius: 8px; padding: 16px; margin: 16px 0; }
  .report .note { background: #ecfeff; border-left: 4px solid #0891b2; padding: 10px 14px; border-radius: 4px; margin: 12px 0; font-size: 0.9em; }
  .report .molecule-card { background: #fff; border: 1px solid #cffafe; border-radius: 8px; padding: 16px; margin: 12px 0; }
  .report .molecule-grid { display: grid; grid-template-columns: repeat(auto-fill, minmax(200px, 1fr)); gap: 12px; margin: 16px 0; }
  .report .funnel { text-align: center; margin: 24px 0; }
</style>
"""

def _wrap_report(html: str) -> str:
    """Wrap HTML content with report styling."""
    return f'{REPORT_CSS}<div class="report">{html}</div>'

# ------------------------------------------------------------------
# SVG chart helpers
# ------------------------------------------------------------------

def _mol_to_data_uri(mol, size: tuple[int, int] = (300, 300)) -> str:
    """Convert an RDKit molecule to a PNG base64 data URI."""
    from rdkit.Chem import Draw

    img = Draw.MolToImage(mol, size=size)
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    b64 = base64.b64encode(buf.getvalue()).decode()
    return f"data:image/png;base64,{b64}"

def _make_bar_chart(
    labels: list[str],
    series: dict[str, list[float]],
    title: str = "",
    colors: list[str] | None = None,
    width: int = 700,
    height: int = 340,
    y_max_cap: float | None = None,
    horizontal: bool = False,
    value_fmt: str = ".1f",
) -> str:
    """Generate an SVG grouped bar chart.

    Args:
        labels: Category labels.
        series: Dict mapping series name to list of values.
        title: Chart title.
        colors: Colors for each series.
        width/height: SVG dimensions.
        y_max_cap: Cap the y-axis at this value.
        horizontal: If True, draw horizontal bars.
        value_fmt: Format string for value labels.

    Returns:
        SVG string.
    """
    if not labels:
        return ""

    default_colors = ["#0891b2", "#0e4f6e", "#06d6a0", "#a5f3fc", "#155e75"]
    colors = colors or default_colors

    if horizontal:
        return _make_horizontal_bar_chart(labels, series, title, colors, width, height, value_fmt)

    ml, mr, mt, mb = 60, 20, 40, 60
    cw = width - ml - mr
    ch = height - mt - mb

    all_vals = [v for vals in series.values() for v in vals]
    y_max = max(all_vals) if all_vals else 1
    y_max_plot = y_max * 1.15 or 1
    if y_max_cap is not None:
        y_max_plot = min(y_max_plot, y_max_cap) or y_max_cap

    n_groups = len(labels)
    n_series = len(series)
    group_width = cw / n_groups
    bar_width = group_width * 0.7 / max(n_series, 1)
    gap = group_width * 0.15

    def sy(v):
        return mt + ch - (v / y_max_plot) * ch

    svg = [
        f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {height}" '
        f'style="width:100%;max-width:{width}px;height:auto;">',
        f'<rect width="{width}" height="{height}" fill="#fff" rx="6"/>',
    ]

    # Grid lines
    for i in range(6):
        y_tick = y_max_plot * i / 5
        py = sy(y_tick)
        svg.append(
            f'<line x1="{ml}" y1="{py:.1f}" x2="{ml + cw}" y2="{py:.1f}" '
            f'stroke="#e0f2fe" stroke-width="1"/>'
        )
        svg.append(
            f'<text x="{ml - 8}" y="{py + 4:.1f}" text-anchor="end" '
            f'font-size="11" fill="#6c757d">{y_tick:{value_fmt}}</text>'
        )

    # Axes
    svg.append(
        f'<line x1="{ml}" y1="{mt}" x2="{ml}" y2="{mt + ch}" '
        f'stroke="#94a3b8" stroke-width="1.5"/>'
    )
    svg.append(
        f'<line x1="{ml}" y1="{mt + ch}" x2="{ml + cw}" y2="{mt + ch}" '
        f'stroke="#94a3b8" stroke-width="1.5"/>'
    )

    # Bars
    for gi, label in enumerate(labels):
        gx = ml + gi * group_width + gap
        for si, (name, vals) in enumerate(series.items()):
            color = colors[si % len(colors)]
            bx = gx + si * bar_width
            val = vals[gi]
            by = sy(val)
            bh = mt + ch - by
            svg.append(
                f'<rect x="{bx:.1f}" y="{by:.1f}" width="{bar_width - 1:.1f}" '
                f'height="{bh:.1f}" fill="{color}" rx="2"/>'
            )
            svg.append(
                f'<text x="{bx + bar_width / 2:.1f}" y="{by - 4:.1f}" '
                f'text-anchor="middle" font-size="9" fill="#1a1a2e">'
                f'{val:{value_fmt}}</text>'
            )
        # Truncate long labels
        disp_label = label if len(label) <= 12 else label[:10] + ".."
        svg.append(
            f'<text x="{gx + n_series * bar_width / 2:.1f}" y="{mt + ch + 16}" '
            f'text-anchor="middle" font-size="10" fill="#6c757d" '
            f'transform="rotate(-35, {gx + n_series * bar_width / 2:.1f}, {mt + ch + 16})">'
            f'{disp_label}</text>'
        )

    # Title
    if title:
        svg.append(
            f'<text x="{width / 2}" y="22" text-anchor="middle" '
            f'font-size="14" font-weight="600" fill="#0e4f6e">{title}</text>'
        )

    # Legend
    if n_series > 1:
        lx = ml + cw - len(series) * 100
        for si, name in enumerate(series):
            color = colors[si % len(colors)]
            svg.append(
                f'<rect x="{lx + si * 100}" y="{mt + ch + 40}" width="12" '
                f'height="12" rx="2" fill="{color}"/>'
            )
            svg.append(
                f'<text x="{lx + si * 100 + 16}" y="{mt + ch + 51}" font-size="11" '
                f'fill="#1a1a2e">{name}</text>'
            )

    svg.append("</svg>")
    return "\n".join(svg)

def _make_horizontal_bar_chart(
    labels: list[str],
    series: dict[str, list[float]],
    title: str = "",
    colors: list[str] | None = None,
    width: int = 700,
    height: int = 400,
    value_fmt: str = ".1f",
) -> str:
    """Generate an SVG horizontal bar chart of the first series; bars are drawn in input order (pre-sort labels and values for a ranked chart)."""
    default_colors = ["#0891b2", "#0e4f6e", "#06d6a0"]
    colors = colors or default_colors

    n = len(labels)
    row_height = max(22, min(35, (height - 80) // max(n, 1)))
    actual_height = max(height, 80 + n * row_height)
    ml, mr, mt, mb = 120, 60, 40, 20
    cw = width - ml - mr
    ch = actual_height - mt - mb

    # Use first series
    first_key = list(series.keys())[0]
    vals = series[first_key]
    x_max = max(vals) * 1.15 if vals else 1

    svg = [
        f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {actual_height}" '
        f'style="width:100%;max-width:{width}px;height:auto;">',
        f'<rect width="{width}" height="{actual_height}" fill="#fff" rx="6"/>',
    ]

    if title:
        svg.append(
            f'<text x="{width / 2}" y="22" text-anchor="middle" '
            f'font-size="14" font-weight="600" fill="#0e4f6e">{title}</text>'
        )

    bar_h = row_height * 0.65
    for i, (label, val) in enumerate(zip(labels, vals)):
        y = mt + i * row_height
        bw = (val / x_max) * cw if x_max else 0
        color = colors[i % len(colors)]
        # Label
        disp = label if len(label) <= 14 else label[:12] + ".."
        svg.append(
            f'<text x="{ml - 8}" y="{y + bar_h / 2 + 4:.1f}" text-anchor="end" '
            f'font-size="11" fill="#1a1a2e">{disp}</text>'
        )
        # Bar
        svg.append(
            f'<rect x="{ml}" y="{y:.1f}" width="{bw:.1f}" height="{bar_h:.1f}" '
            f'fill="{color}" rx="3"/>'
        )
        # Value
        svg.append(
            f'<text x="{ml + bw + 6:.1f}" y="{y + bar_h / 2 + 4:.1f}" '
            f'font-size="11" fill="#0e4f6e" font-weight="600">{val:{value_fmt}}</text>'
        )

    svg.append("</svg>")
    return "\n".join(svg)

def _make_heatmap(
    matrix: list[list[float]],
    row_labels: list[str],
    col_labels: list[str],
    title: str = "",
    color_scale: str = "cyan",
    width: int = 700,
    height: int = 500,
    value_fmt: str = ".2f",
) -> str:
    """Generate an SVG heatmap.

    Args:
        matrix: 2D list of values (rows x cols).
        row_labels: Labels for rows.
        col_labels: Labels for columns.
        title: Chart title.
        color_scale: Color scheme ("cyan", "red", "green").
        width/height: SVG dimensions.
        value_fmt: Format string for cell values.

    Returns:
        SVG string.
    """
    if not matrix or not matrix[0]:
        return ""

    n_rows = len(matrix)
    n_cols = len(matrix[0])

    ml, mr, mt, mb = 110, 20, 70, 20
    cw = width - ml - mr
    ch = height - mt - mb
    cell_w = cw / n_cols
    cell_h = ch / n_rows

    # Flatten to find range
    flat = [v for row in matrix for v in row]
    v_min = min(flat)
    v_max = max(flat)
    v_range = v_max - v_min or 1

    def color_for(v):
        t = (v - v_min) / v_range
        if color_scale == "cyan":
            # White to deep teal
            r = int(255 - t * (255 - 14))
            g = int(255 - t * (255 - 79))
            b = int(255 - t * (255 - 110))
        elif color_scale == "red":
            r = int(255 - t * 50)
            g = int(255 - t * 200)
            b = int(255 - t * 200)
        else:  # green
            r = int(255 - t * 200)
            g = int(255 - t * 50)
            b = int(255 - t * 200)
        return f"rgb({r},{g},{b})"

    svg = [
        f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {height}" '
        f'style="width:100%;max-width:{width}px;height:auto;">',
        f'<rect width="{width}" height="{height}" fill="#fff" rx="6"/>',
    ]

    if title:
        svg.append(
            f'<text x="{width / 2}" y="22" text-anchor="middle" '
            f'font-size="14" font-weight="600" fill="#0e4f6e">{title}</text>'
        )

    # Column labels (rotated)
    for ci, label in enumerate(col_labels):
        x = ml + ci * cell_w + cell_w / 2
        disp = label if len(label) <= 12 else label[:10] + ".."
        svg.append(
            f'<text x="{x:.1f}" y="{mt - 8}" text-anchor="end" font-size="10" '
            f'fill="#1a1a2e" transform="rotate(-45, {x:.1f}, {mt - 8})">{disp}</text>'
        )

    # Row labels + cells
    for ri, (row_label, row_vals) in enumerate(zip(row_labels, matrix)):
        y = mt + ri * cell_h
        disp = row_label if len(row_label) <= 14 else row_label[:12] + ".."
        svg.append(
            f'<text x="{ml - 8}" y="{y + cell_h / 2 + 4:.1f}" text-anchor="end" '
            f'font-size="10" fill="#1a1a2e">{disp}</text>'
        )
        for ci, val in enumerate(row_vals):
            x = ml + ci * cell_w
            fill = color_for(val)
            svg.append(
                f'<rect x="{x:.1f}" y="{y:.1f}" width="{cell_w:.1f}" '
                f'height="{cell_h:.1f}" fill="{fill}" stroke="#fff" stroke-width="1"/>'
            )
            # Text color: dark on light, light on dark
            t = (val - v_min) / v_range
            txt_color = "#fff" if t > 0.55 else "#1a1a2e"
            # Only show text if cells are large enough
            if cell_w > 30 and cell_h > 18:
                svg.append(
                    f'<text x="{x + cell_w / 2:.1f}" y="{y + cell_h / 2 + 4:.1f}" '
                    f'text-anchor="middle" font-size="9" fill="{txt_color}">'
                    f'{val:{value_fmt}}</text>'
                )

    svg.append("</svg>")
    return "\n".join(svg)

def _make_scatter_plot(
    points: list[dict],
    x_label: str = "MW",
    y_label: str = "LogP",
    title: str = "",
    reference_lines: list[dict] | None = None,
    width: int = 700,
    height: int = 400,
) -> str:
    """Generate an SVG scatter plot.

    Args:
        points: List of dicts with "x", "y", "label" keys.
        x_label/y_label: Axis labels.
        title: Chart title.
        reference_lines: List of dicts with "axis" ("x"/"y"), "value", "label".
        width/height: SVG dimensions.

    Returns:
        SVG string.
    """
    if not points:
        return ""

    ml, mr, mt, mb = 60, 30, 40, 50
    cw = width - ml - mr
    ch = height - mt - mb

    x_vals = [p["x"] for p in points]
    y_vals = [p["y"] for p in points]
    x_min, x_max = min(x_vals) * 0.9, max(x_vals) * 1.1
    y_min, y_max = min(y_vals) - 1, max(y_vals) + 1

    # Extend ranges to include reference lines
    if reference_lines:
        for rl in reference_lines:
            if rl["axis"] == "x":
                x_max = max(x_max, rl["value"] * 1.1)
            else:
                y_max = max(y_max, rl["value"] * 1.1)

    x_range = x_max - x_min or 1
    y_range = y_max - y_min or 1

    def sx(v):
        return ml + (v - x_min) / x_range * cw

    def sy(v):
        return mt + ch - (v - y_min) / y_range * ch

    svg = [
        f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {height}" '
        f'style="width:100%;max-width:{width}px;height:auto;">',
        f'<rect width="{width}" height="{height}" fill="#fff" rx="6"/>',
    ]

    # Grid
    for i in range(6):
        y_tick = y_min + y_range * i / 5
        py = sy(y_tick)
        svg.append(
            f'<line x1="{ml}" y1="{py:.1f}" x2="{ml + cw}" y2="{py:.1f}" '
            f'stroke="#e0f2fe" stroke-width="1"/>'
        )
        svg.append(
            f'<text x="{ml - 8}" y="{py + 4:.1f}" text-anchor="end" '
            f'font-size="11" fill="#6c757d">{y_tick:.1f}</text>'
        )

    for i in range(6):
        x_tick = x_min + x_range * i / 5
        px = sx(x_tick)
        svg.append(
            f'<text x="{px:.1f}" y="{mt + ch + 20}" text-anchor="middle" '
            f'font-size="11" fill="#6c757d">{x_tick:.0f}</text>'
        )

    # Axes
    svg.append(
        f'<line x1="{ml}" y1="{mt}" x2="{ml}" y2="{mt + ch}" '
        f'stroke="#94a3b8" stroke-width="1.5"/>'
    )
    svg.append(
        f'<line x1="{ml}" y1="{mt + ch}" x2="{ml + cw}" y2="{mt + ch}" '
        f'stroke="#94a3b8" stroke-width="1.5"/>'
    )

    # Reference lines (Lipinski boundaries)
    if reference_lines:
        for rl in reference_lines:
            if rl["axis"] == "x":
                px = sx(rl["value"])
                svg.append(
                    f'<line x1="{px:.1f}" y1="{mt}" x2="{px:.1f}" y2="{mt + ch}" '
                    f'stroke="#ef4444" stroke-width="1.5" stroke-dasharray="6,4"/>'
                )
                svg.append(
                    f'<text x="{px + 4:.1f}" y="{mt + 14}" font-size="10" '
                    f'fill="#ef4444" font-weight="600">{rl["label"]}</text>'
                )
            else:
                py = sy(rl["value"])
                svg.append(
                    f'<line x1="{ml}" y1="{py:.1f}" x2="{ml + cw}" y2="{py:.1f}" '
                    f'stroke="#ef4444" stroke-width="1.5" stroke-dasharray="6,4"/>'
                )
                svg.append(
                    f'<text x="{ml + cw - 4:.1f}" y="{py - 6:.1f}" text-anchor="end" '
                    f'font-size="10" fill="#ef4444" font-weight="600">{rl["label"]}</text>'
                )

    # Drug-like zone shading (MW<=500 and LogP<=5 quadrant)
    if reference_lines:
        mw_line = next((rl for rl in reference_lines if rl["axis"] == "x"), None)
        logp_line = next((rl for rl in reference_lines if rl["axis"] == "y"), None)
        if mw_line and logp_line:
            zx1 = sx(x_min)
            zx2 = sx(min(mw_line["value"], x_max))
            zy1 = sy(min(logp_line["value"], y_max))
            zy2 = sy(y_min)
            svg.append(
                f'<rect x="{zx1:.1f}" y="{zy1:.1f}" '
                f'width="{zx2 - zx1:.1f}" height="{zy2 - zy1:.1f}" '
                f'fill="#0891b2" opacity="0.06" rx="4"/>'
            )
            svg.append(
                f'<text x="{zx1 + 8:.1f}" y="{zy2 - 8:.1f}" font-size="11" '
                f'fill="#0891b2" font-weight="600" opacity="0.6">Drug-like Zone</text>'
            )

    # Points
    point_colors = ["#0891b2", "#0e4f6e", "#06d6a0", "#155e75", "#0284c7",
                    "#059669", "#0d9488", "#0369a1", "#047857", "#115e59",
                    "#0c4a6e", "#064e3b", "#1e3a5f", "#134e4a", "#075985"]
    for i, pt in enumerate(points):
        px, py = sx(pt["x"]), sy(pt["y"])
        color = point_colors[i % len(point_colors)]
        svg.append(
            f'<circle cx="{px:.1f}" cy="{py:.1f}" r="5" fill="{color}" '
            f'stroke="#fff" stroke-width="1.5" opacity="0.85"/>'
        )
        # Alternate labels above/below the point to reduce overlap
        offset_x = 8
        offset_y = -8 if i % 2 == 0 else 14
        label = pt["label"] if len(pt["label"]) <= 12 else pt["label"][:10] + ".."
        svg.append(
            f'<text x="{px + offset_x:.1f}" y="{py + offset_y:.1f}" '
            f'font-size="9" fill="#1a1a2e">{label}</text>'
        )

    # Title
    if title:
        svg.append(
            f'<text x="{width / 2}" y="22" text-anchor="middle" '
            f'font-size="14" font-weight="600" fill="#0e4f6e">{title}</text>'
        )

    # Axis labels
    if x_label:
        svg.append(
            f'<text x="{ml + cw / 2}" y="{height - 6}" text-anchor="middle" '
            f'font-size="12" fill="#6c757d">{x_label}</text>'
        )
    if y_label:
        svg.append(
            f'<text x="14" y="{mt + ch / 2}" text-anchor="middle" '
            f'font-size="12" fill="#6c757d" '
            f'transform="rotate(-90, 14, {mt + ch / 2})">{y_label}</text>'
        )

    svg.append("</svg>")
    return "\n".join(svg)
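
# Illustrative sketch (hypothetical helper, not called by the pipeline):
# the data-to-pixel mapping that sx()/sy() perform inside _make_scatter_plot.
# x maps left-to-right. y is flipped because SVG's origin is the top-left
# corner, so larger data values must land closer to the top margin (mt).
def _example_data_to_pixel(
    x: float, y: float,
    x_min: float, x_max: float, y_min: float, y_max: float,
    ml: float, mt: float, cw: float, ch: float,
) -> tuple[float, float]:
    x_range = (x_max - x_min) or 1  # same zero-range guard as the plot
    y_range = (y_max - y_min) or 1
    px = ml + (x - x_min) / x_range * cw
    py = mt + ch - (y - y_min) / y_range * ch
    return px, py

# e.g. with a 600x300 plot area offset by (ml=50, mt=40):
# _example_data_to_pixel(0, 0, 0, 100, 0, 10, 50, 40, 600, 300)   -> (50.0, 340.0)
# _example_data_to_pixel(100, 10, 0, 100, 0, 10, 50, 40, 600, 300) -> (650.0, 40.0)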

def _make_funnel(
    stages: list[dict],
    title: str = "",
    width: int = 600,
    height: int = 400,
) -> str:
    """Generate an SVG funnel visualization.

    Args:
        stages: List of dicts with "label", "count", "total" keys.
        title: Chart title.
        width/height: SVG dimensions.

    Returns:
        SVG string.
    """
    if not stages:
        return ""

    n = len(stages)
    mt = 50
    mb = 20
    available_h = height - mt - mb
    stage_h = available_h / n
    cx = width / 2

    # Color gradient from light cyan to deep teal
    colors = []
    for i in range(n):
        t = i / max(n - 1, 1)
        r = int(207 - t * (207 - 14))
        g = int(250 - t * (250 - 79))
        b = int(254 - t * (254 - 110))
        colors.append(f"rgb({r},{g},{b})")

    svg = [
        f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {height}" '
        f'style="width:100%;max-width:{width}px;height:auto;">',
        f'<rect width="{width}" height="{height}" fill="#fff" rx="6"/>',
    ]

    if title:
        svg.append(
            f'<text x="{cx}" y="28" text-anchor="middle" '
            f'font-size="16" font-weight="700" fill="#0e4f6e">{title}</text>'
        )

    max_count = stages[0]["count"] or 1  # guard against division by zero when the first stage is empty
    max_width = width * 0.75

    for i, stage in enumerate(stages):
        y_top = mt + i * stage_h
        y_bot = y_top + stage_h

        # Width proportional to count
        w_top = max_width * (stage["count"] / max_count) if i == 0 else prev_w_bot
        if i < n - 1:
            w_bot = max_width * (stages[i + 1]["count"] / max_count)
        else:
            w_bot = max_width * (stage["count"] / max_count) * 0.7

        prev_w_bot = w_bot

        # Trapezoid
        x1_top = cx - w_top / 2
        x2_top = cx + w_top / 2
        x1_bot = cx - w_bot / 2
        x2_bot = cx + w_bot / 2

        svg.append(
            f'<polygon points="{x1_top:.1f},{y_top:.1f} {x2_top:.1f},{y_top:.1f} '
            f'{x2_bot:.1f},{y_bot:.1f} {x1_bot:.1f},{y_bot:.1f}" '
            f'fill="{colors[i]}" stroke="#fff" stroke-width="2"/>'
        )

        # Text: dark on light, white on dark
        t = i / max(n - 1, 1)
        txt_color = "#0e4f6e" if t < 0.5 else "#fff"
        y_mid = (y_top + y_bot) / 2

        svg.append(
            f'<text x="{cx}" y="{y_mid - 4:.1f}" text-anchor="middle" '
            f'font-size="13" font-weight="600" fill="{txt_color}">{stage["label"]}</text>'
        )
        svg.append(
            f'<text x="{cx}" y="{y_mid + 14:.1f}" text-anchor="middle" '
            f'font-size="12" fill="{txt_color}" opacity="0.85">'
            f'{stage["count"]} / {stage["total"]}</text>'
        )

    svg.append("</svg>")
    return "\n".join(svg)
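
# Illustrative sketch (hypothetical helper, not called by the pipeline):
# the per-stage color in _make_funnel is a linear interpolation between a
# light cyan (t=0, first stage) and a deep teal (t=1, last stage), applied
# channel by channel and truncated to an int.
def _example_funnel_color(
    t: float,
    start: tuple[int, int, int] = (207, 250, 254),
    end: tuple[int, int, int] = (14, 79, 110),
) -> str:
    r, g, b = (int(s - t * (s - e)) for s, e in zip(start, end))
    return f"rgb({r},{g},{b})"

# _example_funnel_color(0.0) -> "rgb(207,250,254)"
# _example_funnel_color(0.5) -> "rgb(110,164,182)"
# _example_funnel_color(1.0) -> "rgb(14,79,110)"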

# ------------------------------------------------------------------
# Task 1: Load and validate molecules
# ------------------------------------------------------------------

@tool
@env.task(cache="auto")
async def load_molecules(
    molecules_json: str = "",
) -> flyte.io.Dir:
    """Parse SMILES strings, validate with RDKit, generate 2D depictions.

    Args:
        molecules_json: JSON string mapping molecule names to SMILES.
            Defaults to a curated library of ~15 well-known drugs.

    Returns:
        flyte.io.Dir containing molecule data (JSON + PNG depictions).
        Pass this directory to compute_properties and generate_report.
    """
    from rdkit import Chem
    from rdkit.Chem import Draw

    if molecules_json.strip():
        molecules = json.loads(molecules_json)
    else:
        molecules = DEFAULT_MOLECULES

    out_dir = tempfile.mkdtemp(prefix="mol_library_")
    results = []
    valid_count = 0
    invalid_count = 0

    log.info(f"Parsing {len(molecules)} molecules...")

    for name, smiles in molecules.items():
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            log.warning(f"  [INVALID] {name}: {smiles}")
            invalid_count += 1
            continue

        valid_count += 1

        # Generate 2D depiction as PNG
        img = Draw.MolToImage(mol, size=(300, 300))
        img_path = os.path.join(out_dir, f"{name.replace(' ', '_')}.png")
        img.save(img_path)

        results.append({
            "name": name,
            "smiles": smiles,
            "valid": True,
            "image_file": os.path.basename(img_path),
        })

    # Save molecule manifest
    manifest = {
        "total": len(molecules),
        "valid": valid_count,
        "invalid": invalid_count,
        "molecules": results,
    }
    manifest_path = os.path.join(out_dir, "manifest.json")
    with open(manifest_path, "w") as f:
        json.dump(manifest, f, indent=2)

    log.info(f"Loaded {valid_count} valid molecules ({invalid_count} invalid)")

    return await flyte.io.Dir.from_local(out_dir)

# ------------------------------------------------------------------
# Task 2: Compute physicochemical properties
# ------------------------------------------------------------------

@tool
@env.task(report=True)
async def compute_properties(
    molecule_dir: flyte.io.Dir,
) -> str:
    """Compute drug-likeness properties for all molecules.

    Computes MW, LogP, HBD, HBA, TPSA, rotatable bonds, formal charge,
    ring count, QED, and Lipinski Rule of Five compliance.

    Args:
        molecule_dir: Directory from load_molecules.

    Returns:
        JSON string with all computed properties. Pass to screen_candidates
        and generate_report.
    """
    from rdkit import Chem
    from rdkit.Chem import Descriptors, Lipinski
    from rdkit.Chem.QED import qed

    # --- Loading report ---
    await flyte.report.replace.aio(
        _wrap_report("<h2>Computing Molecular Properties...</h2>"
                      "<p>Analyzing physicochemical descriptors for all molecules.</p>"),
        do_flush=True,
    )

    mol_dir = await molecule_dir.download()
    with open(os.path.join(mol_dir, "manifest.json")) as f:
        manifest = json.load(f)

    molecules_data = []
    lipinski_pass = 0

    for mol_info in manifest["molecules"]:
        mol = Chem.MolFromSmiles(mol_info["smiles"])
        if mol is None:
            continue

        mw = Descriptors.MolWt(mol)
        logp = Descriptors.MolLogP(mol)
        hbd = Lipinski.NumHDonors(mol)
        hba = Lipinski.NumHAcceptors(mol)
        tpsa = Descriptors.TPSA(mol)
        rotatable = Lipinski.NumRotatableBonds(mol)
        formal_charge = Chem.GetFormalCharge(mol)
        num_rings = Lipinski.RingCount(mol)
        qed_score = qed(mol)

        # Lipinski Rule of Five
        lipinski = {
            "mw_ok": mw <= 500,
            "logp_ok": logp <= 5,
            "hbd_ok": hbd <= 5,
            "hba_ok": hba <= 10,
        }
        lipinski_all = all(lipinski.values())
        if lipinski_all:
            lipinski_pass += 1

        # Read image for data URI
        img_path = os.path.join(mol_dir, mol_info["image_file"])
        data_uri = ""
        if os.path.exists(img_path):
            with open(img_path, "rb") as img_f:
                b64 = base64.b64encode(img_f.read()).decode()
                data_uri = f"data:image/png;base64,{b64}"

        molecules_data.append({
            "name": mol_info["name"],
            "smiles": mol_info["smiles"],
            "mw": round(mw, 2),
            "logp": round(logp, 2),
            "hbd": hbd,
            "hba": hba,
            "tpsa": round(tpsa, 2),
            "rotatable_bonds": rotatable,
            "formal_charge": formal_charge,
            "num_rings": num_rings,
            "qed": round(qed_score, 4),
            "lipinski": lipinski,
            "lipinski_pass": lipinski_all,
            "image_data_uri": data_uri,
        })

    total = len(molecules_data)
    avg_mw = sum(m["mw"] for m in molecules_data) / total if total else 0
    avg_logp = sum(m["logp"] for m in molecules_data) / total if total else 0
    lipinski_rate = lipinski_pass / total * 100 if total else 0

    # ---- Build report ----
    html_parts = []

    # Header
    html_parts.append("<h2>Molecular Properties Analysis</h2>")

    # Stat grid
    html_parts.append('<div class="stat-grid">')
    for val, label in [
        (str(total), "Total Molecules"),
        (f"{lipinski_rate:.0f}%", "Lipinski Pass Rate"),
        (f"{avg_mw:.1f}", "Avg. MW (Da)"),
        (f"{avg_logp:.2f}", "Avg. LogP"),
    ]:
        html_parts.append(
            f'<div class="stat"><div class="value">{val}</div>'
            f'<div class="label">{label}</div></div>'
        )
    html_parts.append("</div>")

    # Molecule gallery
    html_parts.append("<h3>Molecule Library</h3>")
    html_parts.append('<div class="molecule-grid">')
    for m in molecules_data:
        if m["image_data_uri"]:
            badge_class = "badge-success" if m["lipinski_pass"] else "badge-danger"
            badge_text = "Lipinski Pass" if m["lipinski_pass"] else "Lipinski Fail"
            html_parts.append(
                f'<div class="molecule-card" style="text-align:center;">'
                f'<img src="{m["image_data_uri"]}" style="width:160px;height:160px;object-fit:contain;"/>'
                f'<div style="font-weight:600;margin-top:6px;color:#0e4f6e;">{m["name"]}</div>'
                f'<div style="font-size:0.8em;color:#6c757d;">MW: {m["mw"]:.1f} | LogP: {m["logp"]:.2f}</div>'
                f'<div><span class="badge {badge_class}">{badge_text}</span></div>'
                f'</div>'
            )
    html_parts.append("</div>")

    # MW bar chart (horizontal, sorted)
    sorted_by_mw = sorted(molecules_data, key=lambda m: m["mw"], reverse=True)
    mw_labels = [m["name"] for m in sorted_by_mw]
    mw_vals = [m["mw"] for m in sorted_by_mw]
    mw_chart = _make_bar_chart(
        mw_labels, {"MW (Da)": mw_vals},
        title="Molecular Weight Distribution",
        horizontal=True,
        width=700, height=max(300, len(mw_labels) * 30 + 80),
        value_fmt=".1f",
    )
    html_parts.append("<h3>Molecular Weight</h3>")
    html_parts.append(f'<div class="chart-container">{mw_chart}</div>')

    # LogP vs MW scatter plot
    scatter_points = [
        {"x": m["mw"], "y": m["logp"], "label": m["name"]}
        for m in molecules_data
    ]
    scatter_chart = _make_scatter_plot(
        scatter_points,
        x_label="Molecular Weight (Da)",
        y_label="LogP",
        title="LogP vs. Molecular Weight (Lipinski Boundaries)",
        reference_lines=[
            {"axis": "x", "value": 500, "label": "MW = 500"},
            {"axis": "y", "value": 5, "label": "LogP = 5"},
        ],
        width=700,
        height=420,
    )
    html_parts.append("<h3>Lipinski Space</h3>")
    html_parts.append(f'<div class="chart-container">{scatter_chart}</div>')

    # Property heatmap (molecules x properties), skipped when no molecule parsed
    prop_names = ["MW", "LogP", "HBD", "HBA", "TPSA", "Rot. Bonds"]
    raw_matrix = [
        [m["mw"], m["logp"], m["hbd"], m["hba"], m["tpsa"], m["rotatable_bonds"]]
        for m in molecules_data
    ]

    if raw_matrix:
        # Min-max normalize each property column to 0-1; constant columns map to 0.5
        n_props = len(prop_names)
        col_min = [min(row[c] for row in raw_matrix) for c in range(n_props)]
        col_max = [max(row[c] for row in raw_matrix) for c in range(n_props)]
        norm_matrix = []
        for row in raw_matrix:
            norm_row = []
            for c in range(n_props):
                rng = col_max[c] - col_min[c]
                norm_row.append((row[c] - col_min[c]) / rng if rng else 0.5)
            norm_matrix.append(norm_row)

        heatmap_labels = [m["name"] for m in molecules_data]
        heatmap = _make_heatmap(
            norm_matrix, heatmap_labels, prop_names,
            title="Normalized Property Heatmap",
            color_scale="cyan",
            width=700,
            height=max(400, len(heatmap_labels) * 28 + 100),
        )
        html_parts.append("<h3>Property Heatmap</h3>")
        html_parts.append(f'<div class="chart-container">{heatmap}</div>')

    # Lipinski compliance table
    def _badge(ok):
        if ok:
            return '<span class="badge badge-success">Pass</span>'
        return '<span class="badge badge-danger">Fail</span>'

    html_parts.append("<h3>Lipinski Rule of Five Compliance</h3>")
    html_parts.append("<table><tr><th>Molecule</th><th>MW &le; 500</th>"
                      "<th>LogP &le; 5</th><th>HBD &le; 5</th>"
                      "<th>HBA &le; 10</th><th>Overall</th></tr>")
    for m in molecules_data:
        lip = m["lipinski"]
        html_parts.append(
            f'<tr><td><strong>{m["name"]}</strong></td>'
            f'<td>{_badge(lip["mw_ok"])}</td>'
            f'<td>{_badge(lip["logp_ok"])}</td>'
            f'<td>{_badge(lip["hbd_ok"])}</td>'
            f'<td>{_badge(lip["hba_ok"])}</td>'
            f'<td>{_badge(m["lipinski_pass"])}</td></tr>'
        )
    html_parts.append("</table>")

    # QED bar chart
    sorted_by_qed = sorted(molecules_data, key=lambda m: m["qed"], reverse=True)
    qed_labels = [m["name"] for m in sorted_by_qed]
    qed_vals = [m["qed"] for m in sorted_by_qed]
    qed_chart = _make_bar_chart(
        qed_labels, {"QED Score": qed_vals},
        title="Drug-likeness (QED Score)",
        horizontal=True,
        width=700, height=max(300, len(qed_labels) * 30 + 80),
        value_fmt=".3f",
        colors=["#06d6a0"],
    )
    html_parts.append("<h3>Drug-likeness (QED)</h3>")
    html_parts.append(f'<div class="chart-container">{qed_chart}</div>')

    # Flush full report
    await flyte.report.replace.aio(
        _wrap_report("\n".join(html_parts)),
        do_flush=True,
    )

    # Return properties as JSON (strip image data URIs to reduce size)
    output = {
        "total": total,
        "lipinski_pass_count": lipinski_pass,
        "lipinski_pass_rate": round(lipinski_rate, 2),
        "avg_mw": round(avg_mw, 2),
        "avg_logp": round(avg_logp, 2),
        "molecules": [
            {k: v for k, v in m.items() if k != "image_data_uri"}
            for m in molecules_data
        ],
    }
    return json.dumps(output)
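
# Illustrative sketch (hypothetical helper, not called by the pipeline):
# the per-column min-max normalization used for the property heatmap in
# compute_properties. Each column is rescaled to 0-1 independently so that
# properties on very different scales (MW in hundreds, HBD in units) share
# one color scale; a column with a single repeated value maps to 0.5.
def _example_minmax_columns(rows: list[list[float]]) -> list[list[float]]:
    if not rows:
        return []
    n_cols = len(rows[0])
    lo = [min(r[c] for r in rows) for c in range(n_cols)]
    hi = [max(r[c] for r in rows) for c in range(n_cols)]
    return [
        [
            (r[c] - lo[c]) / (hi[c] - lo[c]) if hi[c] != lo[c] else 0.5
            for c in range(n_cols)
        ]
        for r in rows
    ]

# _example_minmax_columns([[1, 10], [3, 10], [2, 10]])
# -> [[0.0, 0.5], [1.0, 0.5], [0.5, 0.5]]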

# ------------------------------------------------------------------
# Task 3: Screen candidates against target profile
# ------------------------------------------------------------------

@tool
@env.task(report=True)
async def screen_candidates(
    properties_json: str,
    target_profile: str = "",
) -> str:
    """Screen molecules against a target drug profile and rank candidates.

    Scores each molecule on how well it matches the target profile, computes
    pairwise Tanimoto similarity, and produces a ranked list.

    Args:
        properties_json: JSON from compute_properties.
        target_profile: JSON string with desired property ranges
            (e.g. {"mw": [150, 500], "logp": [-0.5, 5.0]}).

    Returns:
        JSON string with ranked_molecules, similarity_matrix, similarity_labels,
        funnel, and target_profile. Pass the full return value verbatim to
        generate_report along with molecule_dir and properties_json.
    """
    from rdkit import Chem, DataStructs
    from rdkit.Chem import AllChem

    await flyte.report.replace.aio(
        _wrap_report("<h2>Screening Candidates...</h2>"
                      "<p>Evaluating molecules against the target drug profile.</p>"),
        do_flush=True,
    )

    props = json.loads(properties_json)
    molecules = props["molecules"]

    # Default target profile
    if target_profile.strip():
        profile = json.loads(target_profile)
    else:
        profile = {
            "mw": [150, 500],
            "logp": [-0.5, 5.0],
            "hbd": [0, 5],
            "hba": [0, 10],
            "tpsa": [20, 140],
        }

    # --- Screening ---
    funnel_total = len(molecules)
    pass_mw = 0
    pass_logp = 0
    pass_lipinski = 0
    final_candidates = 0

    scored = []
    for m in molecules:
        score = 0
        max_score = 0
        criteria = {}

        # Check each profile criterion
        checks = [
            ("mw", m["mw"]),
            ("logp", m["logp"]),
            ("hbd", m["hbd"]),
            ("hba", m["hba"]),
            ("tpsa", m["tpsa"]),
        ]

        for key, val in checks:
            if key in profile:
                lo, hi = profile[key]
                max_score += 1.5  # 1 point for being in range + up to 0.5 midpoint bonus
                in_range = lo <= val <= hi
                criteria[key] = in_range
                if in_range:
                    score += 1
                    # Bonus: closer to midpoint = higher score
                    mid = (lo + hi) / 2
                    rng = (hi - lo) / 2
                    dist = abs(val - mid) / rng if rng else 0
                    score += max(0, 0.5 * (1 - dist))

        # QED bonus
        score += m["qed"] * 2
        max_score += 2

        # Lipinski bonus
        if m["lipinski_pass"]:
            score += 1
        max_score += 1

        normalized_score = score / max_score if max_score else 0

        # Funnel tracking: cascading filter (each stage requires passing the previous one)
        mw_ok = criteria.get("mw", True)
        logp_ok = criteria.get("logp", True)
        if mw_ok:
            pass_mw += 1
            if logp_ok:
                pass_logp += 1
                if m["lipinski_pass"]:
                    pass_lipinski += 1
                    if all(criteria.values()):
                        final_candidates += 1

        scored.append({
            **m,
            "screening_score": round(normalized_score, 4),
            "criteria_met": criteria,
            "all_criteria_met": all(criteria.values()),
        })

    # Sort by score descending
    scored.sort(key=lambda m: m["screening_score"], reverse=True)

    # --- Tanimoto similarity matrix ---
    fps = []
    valid_names = []
    for m in scored:
        mol = Chem.MolFromSmiles(m["smiles"])
        if mol:
            # Morgan fingerprint, radius 2 (ECFP4-like), folded to 2048 bits
            fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048)
            fps.append(fp)
            valid_names.append(m["name"])

    # One BulkTanimotoSimilarity call per row instead of n*n single comparisons
    similarity_matrix = [
        [round(sim, 3) for sim in DataStructs.BulkTanimotoSimilarity(fp, fps)]
        for fp in fps
    ]

    # ---- Build report ----
    html_parts = []
    html_parts.append("<h2>Candidate Screening Results</h2>")

    # Stat grid
    html_parts.append('<div class="stat-grid">')
    for val, label in [
        (str(funnel_total), "Total Screened"),
        (str(pass_lipinski), "Lipinski Passes"),
        (str(final_candidates), "All Criteria Met"),
        (f"{scored[0]['screening_score']:.3f}" if scored else "N/A", "Top Score"),
    ]:
        html_parts.append(
            f'<div class="stat"><div class="value">{val}</div>'
            f'<div class="label">{label}</div></div>'
        )
    html_parts.append("</div>")

    # Screening funnel
    funnel_stages = [
        {"label": "Total Molecules", "count": funnel_total, "total": funnel_total},
        {"label": "Pass MW Filter", "count": pass_mw, "total": funnel_total},
        {"label": "Pass LogP Filter", "count": pass_logp, "total": funnel_total},
        {"label": "Lipinski Compliant", "count": pass_lipinski, "total": funnel_total},
        {"label": "All Criteria Met", "count": final_candidates, "total": funnel_total},
    ]
    funnel_svg = _make_funnel(
        funnel_stages,
        title="Screening Funnel",
        width=600,
        height=380,
    )
    html_parts.append("<h3>Screening Funnel</h3>")
    html_parts.append(f'<div class="chart-container" style="text-align:center;">{funnel_svg}</div>')

    # Ranked candidates table
    html_parts.append("<h3>Ranked Candidates</h3>")
    html_parts.append(
        "<table><tr><th>Rank</th><th>Molecule</th><th>Score</th>"
        "<th>MW</th><th>LogP</th><th>QED</th><th>Lipinski</th><th>All Criteria</th></tr>"
    )
    for rank, m in enumerate(scored, 1):
        lip_badge = ('<span class="badge badge-success">Pass</span>'
                     if m["lipinski_pass"]
                     else '<span class="badge badge-danger">Fail</span>')
        crit_badge = ('<span class="badge badge-success">Pass</span>'
                      if m["all_criteria_met"]
                      else '<span class="badge badge-danger">Fail</span>')
        # Highlight top 3
        row_style = ' style="background:#ecfeff;font-weight:600;"' if rank <= 3 else ""
        html_parts.append(
            f"<tr{row_style}><td>{rank}</td><td>{m['name']}</td>"
            f"<td>{m['screening_score']:.3f}</td>"
            f"<td>{m['mw']:.1f}</td><td>{m['logp']:.2f}</td>"
            f"<td>{m['qed']:.3f}</td><td>{lip_badge}</td><td>{crit_badge}</td></tr>"
        )
    html_parts.append("</table>")

    # Top 5 candidate cards with structures
    html_parts.append("<h3>Top 5 Candidates</h3>")
    html_parts.append('<div class="molecule-grid">')
    for m in scored[:5]:
        mol = Chem.MolFromSmiles(m["smiles"])
        img_uri = _mol_to_data_uri(mol, size=(250, 250)) if mol else ""
        badge_class = "badge-success" if m["all_criteria_met"] else "badge-info"
        badge_text = "All Criteria Met" if m["all_criteria_met"] else "Partial Match"
        html_parts.append(
            f'<div class="molecule-card" style="text-align:center;">'
            f'<img src="{img_uri}" style="width:140px;height:140px;object-fit:contain;"/>'
            f'<div style="font-weight:700;margin-top:6px;color:#0e4f6e;font-size:1.05em;">{m["name"]}</div>'
            f'<div style="font-size:0.85em;color:#155e75;margin:4px 0;">Score: {m["screening_score"]:.3f}</div>'
            f'<div style="font-size:0.8em;color:#6c757d;">MW: {m["mw"]:.1f} | LogP: {m["logp"]:.2f} | QED: {m["qed"]:.3f}</div>'
            f'<div style="margin-top:4px;"><span class="badge {badge_class}">{badge_text}</span></div>'
            f'</div>'
        )
    html_parts.append("</div>")

    # Tanimoto similarity heatmap
    if similarity_matrix:
        sim_heatmap = _make_heatmap(
            similarity_matrix, valid_names, valid_names,
            title="Pairwise Tanimoto Similarity (Morgan Fingerprints)",
            color_scale="cyan",
            width=700,
            height=max(500, len(valid_names) * 32 + 100),
        )
        html_parts.append("<h3>Chemical Similarity</h3>")
        html_parts.append(f'<div class="chart-container">{sim_heatmap}</div>')

    await flyte.report.replace.aio(
        _wrap_report("\n".join(html_parts)),
        do_flush=True,
    )

    output = {
        "ranked_molecules": scored,
        "similarity_matrix": similarity_matrix,
        "similarity_labels": valid_names,
        "funnel": funnel_stages,
        "target_profile": profile,
    }
    return json.dumps(output)
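
# Illustrative sketch (hypothetical helper, not called by the pipeline):
# the points screen_candidates awards for a single profile criterion.
# A value inside [lo, hi] earns 1 point plus a bonus of up to 0.5 that
# shrinks linearly from the midpoint of the window to its edges; a value
# outside the window earns nothing.
def _example_range_points(val: float, lo: float, hi: float) -> float:
    if not lo <= val <= hi:
        return 0.0
    mid = (lo + hi) / 2
    half_width = (hi - lo) / 2
    dist = abs(val - mid) / half_width if half_width else 0.0
    return 1.0 + max(0.0, 0.5 * (1 - dist))

# For the default MW window [150, 500]:
# MW 325 (midpoint) -> 1.5, MW 500 (edge) -> 1.0, MW 650 (outside) -> 0.0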

def _parse_screening_json(screening_json: str) -> dict:
    """Parse screening JSON from screen_candidates, with safe defaults.

    The agent must pass the exact tool return value. Partial or hand-built JSON
    is tolerated for optional similarity fields only.
    """
    screening = json.loads(screening_json)
    if "ranked_molecules" not in screening:
        raise ValueError(
            "screening_json must be the exact JSON string returned by "
            "screen_candidates (missing 'ranked_molecules'). Do not construct, "
            "truncate, or summarize tool output."
        )
    screening.setdefault("similarity_matrix", [])
    screening.setdefault("similarity_labels", [])
    return screening
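
# Illustrative sketch (hypothetical helper, not called by the pipeline):
# the Tanimoto coefficient that DataStructs.TanimotoSimilarity computes on
# the Morgan bit vectors in screen_candidates, written over plain Python sets
# of "on" bit indices: shared bits divided by bits set in either fingerprint.
# Returning 0.0 for two empty sets is a convention chosen for this sketch.
def _example_tanimoto(on_bits_a: set[int], on_bits_b: set[int]) -> float:
    union = on_bits_a | on_bits_b
    if not union:
        return 0.0
    return len(on_bits_a & on_bits_b) / len(union)

# _example_tanimoto({1, 2, 3}, {2, 3, 4}) -> 0.5  (2 shared bits / 4 total)
# _example_tanimoto({5}, {5})             -> 1.0  (identical fingerprints)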

# ------------------------------------------------------------------
# Task 4: Generate final comprehensive report
# ------------------------------------------------------------------

@tool
@env.task(report=True)
async def generate_report(
    molecule_dir: flyte.io.Dir,
    properties_json: str,
    screening_json: str,
) -> str:
    """Generate a comprehensive drug screening report.

    Produces an executive summary, top candidate spotlight cards, property
    distributions, chemical diversity analysis, and final recommendation.

    Args:
        molecule_dir: Directory from load_molecules.
        properties_json: JSON from compute_properties.
        screening_json: Exact verbatim JSON string returned by screen_candidates
            (must include ranked_molecules, similarity_matrix, similarity_labels).
            Do not construct or summarize this payload yourself.

    Returns:
        JSON summary with total_screened, lipinski_passes, all_criteria_met,
        top_candidate, top_score, and top_3 ranked molecules.
    """
    from rdkit import Chem

    await flyte.report.replace.aio(
        _wrap_report("<h2>Generating Final Report...</h2>"),
        do_flush=True,
    )

    props = json.loads(properties_json)
    screening = _parse_screening_json(screening_json)
    ranked = screening["ranked_molecules"]
    sim_matrix = screening["similarity_matrix"]
    sim_labels = screening["similarity_labels"]

    total = props["total"]
    lipinski_pass = props["lipinski_pass_count"]
    all_criteria = sum(1 for m in ranked if m["all_criteria_met"])
    top = ranked[0] if ranked else None

    html_parts = []

    # --- Executive Summary ---
    html_parts.append("<h2>Drug Molecule Screening Report</h2>")
    top_name = top["name"] if top else "N/A"
    top_score = f'{top["screening_score"]:.3f}' if top else "N/A"
    html_parts.append(
        f'<div class="card">'
        f'<h3 style="margin-top:0;color:#0e4f6e;">Executive Summary</h3>'
        f'<p style="font-size:1.05em;">'
        f'<strong>{total}</strong> molecules were screened against the target drug profile. '
        f'<strong>{lipinski_pass}</strong> passed Lipinski\'s Rule of Five, and '
        f'<strong>{all_criteria}</strong> met all screening criteria. '
        f'The top candidate is <strong style="color:#0891b2;">{top_name}</strong> '
        f'with a screening score of <strong>{top_score}</strong>.</p>'
        f'</div>'
    )

    # Stat grid
    html_parts.append('<div class="stat-grid">')
    for val, label in [
        (str(total), "Molecules Screened"),
        (str(lipinski_pass), "Lipinski Passes"),
        (str(all_criteria), "All Criteria Met"),
        (top_score, "Top Score"),
        (f'{props["avg_mw"]:.0f} Da', "Avg. Molecular Weight"),
        (f'{props["avg_logp"]:.2f}', "Avg. LogP"),
    ]:
        html_parts.append(
            f'<div class="stat"><div class="value">{val}</div>'
            f'<div class="label">{label}</div></div>'
        )
    html_parts.append("</div>")

    # --- Top 3 Candidate Spotlights ---
    html_parts.append("<h2>Top Candidate Spotlights</h2>")

    for rank, m in enumerate(ranked[:3], 1):
        mol = Chem.MolFromSmiles(m["smiles"])
        img_uri = _mol_to_data_uri(mol, size=(300, 300)) if mol else ""

        medal = ["gold", "silver", "#cd7f32"][rank - 1]
        medal_emoji = ["1st", "2nd", "3rd"][rank - 1]

        lip_badges = ""
        for rule, key in [("MW", "mw_ok"), ("LogP", "logp_ok"),
                          ("HBD", "hbd_ok"), ("HBA", "hba_ok")]:
            ok = m["lipinski"].get(key, False)
            cls = "badge-success" if ok else "badge-danger"
            lip_badges += f'<span class="badge {cls}" style="margin:2px;">{rule}</span> '

        html_parts.append(
            f'<div class="molecule-card" style="display:flex;gap:20px;align-items:flex-start;flex-wrap:wrap;">'
            f'<div style="text-align:center;min-width:180px;">'
            f'<div style="font-size:1.6em;font-weight:800;color:{medal};">{medal_emoji}</div>'
            f'<img src="{img_uri}" style="width:200px;height:200px;object-fit:contain;border-radius:8px;'
            f'border:2px solid #a5f3fc;"/>'
            f'<div style="font-weight:700;font-size:1.1em;color:#0e4f6e;margin-top:8px;">{m["name"]}</div>'
            f'</div>'
            f'<div style="flex:1;min-width:280px;">'
            f'<table style="margin:0;">'
            f'<tr><td><strong>SMILES</strong></td><td style="font-family:monospace;font-size:0.8em;word-break:break-all;">{m["smiles"]}</td></tr>'
            f'<tr><td><strong>Screening Score</strong></td><td style="font-weight:700;color:#0891b2;font-size:1.1em;">{m["screening_score"]:.3f}</td></tr>'
            f'<tr><td><strong>Molecular Weight</strong></td><td>{m["mw"]:.1f} Da</td></tr>'
            f'<tr><td><strong>LogP</strong></td><td>{m["logp"]:.2f}</td></tr>'
            f'<tr><td><strong>H-Bond Donors</strong></td><td>{m["hbd"]}</td></tr>'
            f'<tr><td><strong>H-Bond Acceptors</strong></td><td>{m["hba"]}</td></tr>'
            f'<tr><td><strong>TPSA</strong></td><td>{m["tpsa"]:.1f} &Aring;&sup2;</td></tr>'
            f'<tr><td><strong>Rotatable Bonds</strong></td><td>{m["rotatable_bonds"]}</td></tr>'
            f'<tr><td><strong>QED</strong></td><td>{m["qed"]:.4f}</td></tr>'
            f'<tr><td><strong>Lipinski Compliance</strong></td><td>{lip_badges}</td></tr>'
            f'</table>'
            f'</div>'
            f'</div>'
        )

    # --- Property Distribution (box-plot style as bars with min/max/median) ---
    html_parts.append("<h2>Property Distributions</h2>")

    prop_keys = [("mw", "Molecular Weight (Da)"), ("logp", "LogP"),
                 ("tpsa", "TPSA"), ("qed", "QED Score")]
    for key, label in prop_keys:
        vals = sorted([m[key] for m in ranked])
        n = len(vals)
        if n == 0:
            continue
        v_min = vals[0]
        v_max = vals[-1]
        median = vals[n // 2] if n % 2 == 1 else (vals[n // 2 - 1] + vals[n // 2]) / 2
        q1 = vals[n // 4] if n >= 4 else v_min
        q3 = vals[3 * n // 4] if n >= 4 else v_max

        # Simple horizontal box-plot as SVG
        box_w = 500
        box_h = 50
        margin_l = 10
        v_range = v_max - v_min or 1

        def sx(v):
            return margin_l + ((v - v_min) / v_range) * (box_w - 2 * margin_l)

        box_svg = (
            f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {box_w} {box_h}" '
            f'style="width:100%;max-width:{box_w}px;height:auto;">'
            f'<rect width="{box_w}" height="{box_h}" fill="#fff" rx="4"/>'
            # Whisker line
            f'<line x1="{sx(v_min):.1f}" y1="25" x2="{sx(v_max):.1f}" y2="25" '
            f'stroke="#94a3b8" stroke-width="1.5"/>'
            # Min whisker
            f'<line x1="{sx(v_min):.1f}" y1="18" x2="{sx(v_min):.1f}" y2="32" '
            f'stroke="#94a3b8" stroke-width="1.5"/>'
            # Max whisker
            f'<line x1="{sx(v_max):.1f}" y1="18" x2="{sx(v_max):.1f}" y2="32" '
            f'stroke="#94a3b8" stroke-width="1.5"/>'
            # IQR box
            f'<rect x="{sx(q1):.1f}" y="14" width="{sx(q3) - sx(q1):.1f}" height="22" '
            f'fill="#a5f3fc" stroke="#0891b2" stroke-width="1.5" rx="3"/>'
            # Median line
            f'<line x1="{sx(median):.1f}" y1="12" x2="{sx(median):.1f}" y2="38" '
            f'stroke="#0e4f6e" stroke-width="2"/>'
            # Labels
            f'<text x="{sx(v_min):.1f}" y="46" text-anchor="middle" font-size="9" fill="#6c757d">{v_min:.1f}</text>'
            f'<text x="{sx(median):.1f}" y="10" text-anchor="middle" font-size="9" fill="#0e4f6e" font-weight="600">{median:.1f}</text>'
            f'<text x="{sx(v_max):.1f}" y="46" text-anchor="middle" font-size="9" fill="#6c757d">{v_max:.1f}</text>'
            f'</svg>'
        )
        html_parts.append(
            f'<div style="margin:8px 0;"><strong style="color:#155e75;">{label}</strong>'
            f'<div class="chart-container" style="padding:8px;">{box_svg}</div></div>'
        )

    # --- Chemical Diversity ---
    html_parts.append("<h2>Chemical Diversity Analysis</h2>")

    if sim_matrix and len(sim_matrix) > 1:
        # Compute average pairwise similarity (off-diagonal)
        n_mols = len(sim_matrix)
        off_diag = []
        for i in range(n_mols):
            for j in range(i + 1, n_mols):
                off_diag.append(sim_matrix[i][j])

        avg_sim = sum(off_diag) / len(off_diag) if off_diag else 0
        max_sim = max(off_diag) if off_diag else 0
        min_sim = min(off_diag) if off_diag else 0

        # Find most similar pair
        best_i, best_j = 0, 1
        best_val = 0
        for i in range(n_mols):
            for j in range(i + 1, n_mols):
                if sim_matrix[i][j] > best_val:
                    best_val = sim_matrix[i][j]
                    best_i, best_j = i, j

        html_parts.append('<div class="stat-grid">')
        html_parts.append(
            f'<div class="stat"><div class="value">{avg_sim:.3f}</div>'
            f'<div class="label">Avg. Pairwise Similarity</div></div>'
        )
        html_parts.append(
            f'<div class="stat"><div class="value">{min_sim:.3f}</div>'
            f'<div class="label">Min Similarity</div></div>'
        )
        html_parts.append(
            f'<div class="stat"><div class="value">{max_sim:.3f}</div>'
            f'<div class="label">Max Similarity</div></div>'
        )
        html_parts.append("</div>")

        diversity_text = "highly diverse" if avg_sim < 0.3 else "moderately diverse" if avg_sim < 0.5 else "relatively similar"
        html_parts.append(
            f'<div class="note">'
            f'The library is <strong>{diversity_text}</strong> (avg. Tanimoto = {avg_sim:.3f}). '
            f'The most similar pair is <strong>{sim_labels[best_i]}</strong> and '
            f'<strong>{sim_labels[best_j]}</strong> (similarity = {best_val:.3f}).</div>'
        )

    # --- Recommendation ---
    html_parts.append("<h2>Recommendation</h2>")
    if top:
        html_parts.append(
            f'<div class="card">'
            f'<h3 style="margin-top:0;color:#0891b2;">Top Candidate: {top["name"]}</h3>'
            f'<p>Based on the virtual screening analysis, <strong>{top["name"]}</strong> '
            f'achieved the highest composite screening score of <strong>{top["screening_score"]:.3f}</strong>. '
        )

        reasons = []
        if top["lipinski_pass"]:
            reasons.append("full Lipinski Rule of Five compliance")
        if top["qed"] > 0.5:
            reasons.append(f"high drug-likeness (QED = {top['qed']:.3f})")
        if top.get("all_criteria_met"):
            reasons.append("all target profile criteria met")
        if top["mw"] <= 500:
            reasons.append(f"favorable molecular weight ({top['mw']:.1f} Da)")

        if reasons:
            html_parts.append(
                f'This candidate stands out due to: {", ".join(reasons)}.</p>'
            )
        else:
            html_parts.append("</p>")

        # Runner-up mentions
        if len(ranked) >= 2:
            html_parts.append(
                f'<p style="font-size:0.9em;color:#6c757d;">Runner-up candidates: '
            )
            runners = []
            for m in ranked[1:4]:
                runners.append(f'{m["name"]} (score: {m["screening_score"]:.3f})')
            html_parts.append(", ".join(runners) + ".</p>")

        html_parts.append("</div>")

    # Final note
    html_parts.append(
        '<div class="note">'
        "This is a virtual screening analysis. All candidates should undergo "
        "further computational validation (molecular dynamics, docking) and "
        "experimental testing before advancing to clinical trials.</div>"
    )

    await flyte.report.replace.aio(
        _wrap_report("\n".join(html_parts)),
        do_flush=True,
    )

    # JSON summary
    summary = {
        "total_screened": total,
        "lipinski_passes": lipinski_pass,
        "all_criteria_met": all_criteria,
        "top_candidate": top["name"] if top else None,
        "top_score": top["screening_score"] if top else None,
        "top_3": [
            {"name": m["name"], "score": m["screening_score"]}
            for m in ranked[:3]
        ],
    }
    return json.dumps(summary)

# ------------------------------------------------------------------
# Agent
# ------------------------------------------------------------------

# {{docs-fragment agent}}
SCREENING_AGENT_INSTRUCTIONS = """\
You are a medicinal chemistry screening strategist. You orchestrate a virtual \
screening pipeline using durable Flyte tools. You NEVER invent molecular \
properties — only RDKit tools compute them.

Workflow:
1. If target_profile is not provided in the user message, derive a JSON \
target_profile from the therapeutic brief. Valid keys: mw, logp, hbd, hba, tpsa \
(each [min, max]). Ground choices in oral bioavailability / kinase / CNS rules \
as appropriate to the brief.
2. First pass (always): load_molecules → compute_properties → \
screen_candidates → generate_report. Pass tool outputs between steps exactly \
(molecule_dir from load_molecules into compute_properties and generate_report; \
properties_json from compute_properties into screen_candidates and \
generate_report; screening_json must be the complete, unmodified string \
returned by screen_candidates — never rebuild or summarize JSON yourself).
3. Read the JSON summary returned by generate_report. Reflect:
   - If all_criteria_met == 0: relax exactly ONE profile bound by ~10–20% \
and re-run screen_candidates then generate_report only, reusing the same \
molecule_dir and properties_json from the first pass.
   - If all molecules pass but diversity is a stated goal: note high similarity \
in your summary; do not re-run unless brief asks for stricter filters.
   - Maximum ONE rescreen iteration.
4. Finish with plain text: top candidate, rationale tied to computed metrics \
from the tool JSON, funnel interpretation, and suggested next steps (docking, \
ADMET lab tests).

If the user supplies an explicit target_profile JSON, use it as-is.

Do NOT ask the user for SMILES or molecule lists when molecules_json is empty — \
the default library is loaded automatically.
"""

screening_agent = Agent(
    name="drug-screening-agent",
    instructions=SCREENING_AGENT_INSTRUCTIONS,
    model=MODEL,
    tools=[
        load_molecules,
        compute_properties,
        screen_candidates,
        generate_report,
    ],
    max_turns=12,
)
# {{/docs-fragment agent}}

# ------------------------------------------------------------------
# Pipeline
# ------------------------------------------------------------------

# {{docs-fragment pipeline}}
@env.task(report=True)
async def pipeline(
    brief: str = "Screen the default drug library for orally bioavailable small molecules.",
    molecules_json: str = "",
    target_profile: str = "",
) -> str:
    """Agentic virtual drug molecule screening pipeline.

    A medicinal-chemistry agent interprets the screening brief, derives or
    applies a target profile, orchestrates the RDKit screening stages, and
    optionally re-screens when funnel results are too narrow.

    Args:
        brief: Natural-language therapeutic goal (e.g. oral kinase inhibitors,
            CNS-penetrant small molecules).
        molecules_json: JSON mapping molecule names to SMILES strings.
            Defaults to a curated library of ~15 well-known drugs.
        target_profile: Optional JSON with desired property ranges that
            overrides agent-derived criteria
            (e.g. {"mw": [150, 500], "logp": [-0.5, 5]}).

    Returns:
        Agent summary with screening rationale and key results.
    """
    prompt_parts = [
        f"Screening brief: {brief}",
        'Use molecules_json="" for the built-in default library unless provided below.',
        "Compose the four stage tools in order: load_molecules → compute_properties "
        "→ screen_candidates → generate_report. Pass each tool's full return value "
        "verbatim to the next step (especially screening_json). Re-run "
        "screen_candidates and generate_report at most once if the funnel is too narrow.",
    ]
    if molecules_json.strip():
        prompt_parts.append(f"molecules_json: {molecules_json}")
    if target_profile.strip():
        prompt_parts.append(f"Use this target_profile exactly: {target_profile}")

    result = await screening_agent.run.aio("\n".join(prompt_parts))
    return result.summary or result.error or ""

# {{/docs-fragment pipeline}}

# ------------------------------------------------------------------
# Rescreen demo — tight profile + explicit rescreen instructions
# ------------------------------------------------------------------

# Initial profile is deliberately strict (narrow MW + low LogP cap) so
# all_criteria_met is typically 0 on the default library; the brief then
# forces a single rescreen with a widened LogP window.
RESCREEN_DEMO_TARGET_PROFILE = (
    '{"mw": [150, 200], "logp": [-0.5, 1.0], "hbd": [0, 1], '
    '"hba": [0, 3], "tpsa": [20, 45]}'
)
RESCREEN_DEMO_TARGET_PROFILE_RESCREEN = (
    '{"mw": [150, 200], "logp": [-0.5, 3.5], "hbd": [0, 1], '
    '"hba": [0, 3], "tpsa": [20, 45]}'
)
RESCREEN_DEMO_BRIEF = f"""\
Two-round agentic screening demo on the default library.

**Round 1 (strict profile):** load_molecules → compute_properties → \
screen_candidates → generate_report using the initial target_profile exactly.

**Round 2 (required — do not skip):** call screen_candidates then generate_report \
again, reusing the same molecule_dir and properties_json from round 1, with this \
relaxed target_profile (wider LogP window only): \
{RESCREEN_DEMO_TARGET_PROFILE_RESCREEN}

Pass every tool return value verbatim to the next step. After both rounds, \
summarize how the funnel and top candidates changed between round 1 and round 2."""

# {{docs-fragment rescreen_demo}}
@env.task(report=True)
async def rescreen_demo() -> str:
    """Example run with a two-round execution graph (rescreen).

    Round 1 uses a strict CNS-like profile; round 2 always re-runs
    screen_candidates and generate_report with a widened LogP window,
    reusing cached molecule_dir and properties_json.
    """
    return await pipeline(
        brief=RESCREEN_DEMO_BRIEF,
        target_profile=RESCREEN_DEMO_TARGET_PROFILE,
    )

# {{/docs-fragment rescreen_demo}}

# {{docs-fragment main}}
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(pipeline)
    print(run.url)
    run.wait()
# {{/docs-fragment main}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/drug_molecule_screening/drug_molecule_screening.py*
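The diversity section of the report reduces to pairwise Tanimoto similarity between molecular fingerprints. The tutorial computes `sim_matrix` elsewhere (not shown in this excerpt). This minimal sketch instead represents each fingerprint as a Python set of on-bit indices, using made-up toy data rather than RDKit output, so it runs without RDKit. It applies the same off-diagonal averaging and 0.3/0.5 thresholds as the report code above:

```python
from itertools import combinations


def tanimoto(a: set[int], b: set[int]) -> float:
    """Tanimoto (Jaccard) similarity of two fingerprints given as sets of on-bits."""
    union = len(a | b)
    return len(a & b) / union if union else 0.0


def diversity_summary(fps: dict[str, set[int]]) -> tuple[float, tuple[str, str], str]:
    """Average pairwise similarity, most similar pair, and the report's diversity label."""
    # Each unordered pair once, i.e. the off-diagonal upper triangle of the matrix
    pairs = [(x, y, tanimoto(fps[x], fps[y])) for x, y in combinations(fps, 2)]
    avg = sum(s for _, _, s in pairs) / len(pairs)
    best = max(pairs, key=lambda p: p[2])
    # Same thresholds as the report: < 0.3 highly diverse, < 0.5 moderately diverse
    label = (
        "highly diverse" if avg < 0.3
        else "moderately diverse" if avg < 0.5
        else "relatively similar"
    )
    return avg, (best[0], best[1]), label


# Hypothetical toy fingerprints (not real RDKit output)
fps = {"A": {1, 2, 3, 4}, "B": {1, 2, 3, 5}, "C": {7, 8, 9}}
print(diversity_summary(fps))
```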

To launch the two-round rescreen demo:

```
flyte run drug_molecule_screening.py rescreen_demo
```

Or pass the same inputs to `pipeline` directly:

```
flyte run drug_molecule_screening.py pipeline \
  --brief "Screen the default library. If all_criteria_met is 0 after generate_report, re-run screen_candidates and generate_report with target_profile {\"mw\": [150, 200], \"logp\": [-0.5, 3.5], \"hbd\": [0, 1], \"hba\": [0, 3], \"tpsa\": [20, 45]}." \
  --target_profile '{"mw": [150, 200], "logp": [-0.5, 1.0], "hbd": [0, 1], "hba": [0, 3], "tpsa": [20, 45]}'
```

Open the run URL and follow the report panel for funnel charts, property distributions, top-candidate spotlights, and the agent's final screening summary. A successful rescreen demo shows two rounds of `screen_candidates` and `generate_report` in the action tree.
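The rescreen works because each target-profile entry is an inclusive `[min, max]` window, and a molecule meets the profile only when every computed property falls inside its window. The helpers `failing_properties` and `meets_profile` below are a minimal sketch of that check, not the tutorial's `screen_candidates` tool, and the candidate's property values are made up. The sketch shows how widening only the LogP window, as the demo does, can turn a miss into a hit:

```python
import json

# The rescreen demo's two profiles, copied from the constants above
STRICT = json.loads(
    '{"mw": [150, 200], "logp": [-0.5, 1.0], "hbd": [0, 1], '
    '"hba": [0, 3], "tpsa": [20, 45]}'
)
RELAXED = json.loads(
    '{"mw": [150, 200], "logp": [-0.5, 3.5], "hbd": [0, 1], '
    '"hba": [0, 3], "tpsa": [20, 45]}'
)


def failing_properties(props: dict[str, float], profile: dict[str, list[float]]) -> list[str]:
    """Names of profiled properties that fall outside their inclusive [min, max] window."""
    return [key for key, (lo, hi) in profile.items() if not lo <= props[key] <= hi]


def meets_profile(props: dict[str, float], profile: dict[str, list[float]]) -> bool:
    """True when no profiled property is out of range."""
    return not failing_properties(props, profile)


# Hypothetical computed properties for one molecule (made-up values)
candidate = {"mw": 180.0, "logp": 1.8, "hbd": 1, "hba": 3, "tpsa": 40.0}
print(failing_properties(candidate, STRICT))  # ['logp']
print(meets_profile(candidate, RELAXED))      # True
```

Because the two profiles differ only in `logp`, any molecule that passes in round 2 but not in round 1 must have failed round 1 on LogP alone.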

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/geospatial ===

# Geospatial

Tutorials for satellite imagery, remote sensing, and earth and atmospheric modeling workloads.

### **Geospatial > GPU-accelerated climate modeling**

Run ensemble atmospheric simulations on H200 GPUs with multi-source data ingestion and real-time extreme event detection.

### **Geospatial > Satellite image classification**

Build a production-grade EfficientNet pipeline for land-use classification with caching, experiment tracking, and reporting.

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/geospatial/climate-modeling ===

# GPU-accelerated climate modeling

Climate modeling is hard for two reasons: data and compute. Satellite imagery arrives continuously from multiple sources. Reanalysis datasets have to be pulled from remote APIs. Weather station data shows up in different formats and schemas. And once all of that is finally in one place, running atmospheric physics simulations demands serious GPU compute.

In practice, many climate workflows are held together with scripts, cron jobs, and a lot of manual babysitting. Data ingestion breaks without warning. GPU jobs run overnight with little visibility into what's happening. When something interesting shows up in a simulation, like a developing hurricane, no one notices until the job finishes hours later.

In this tutorial, we build a production-grade climate modeling pipeline using Flyte. We ingest data from three different sources in parallel, combine it with Dask, run ensemble atmospheric simulations on H200 GPUs, detect extreme weather events as they emerge, and visualize everything in a live dashboard. The entire pipeline is orchestrated, cached, and fault-tolerant, so it can run reliably at scale.

![Report](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/tutorials/climate-modeling/report.png)

> [!NOTE]
> The full code is available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/climate_modeling/simulation.py).

## Overview

We're building an ensemble weather forecasting system. Ensemble forecasting runs the same simulation many times from slightly perturbed initial conditions, and the spread of the results quantifies forecast uncertainty. Instead of saying "the temperature will be 25°C", we can say "the temperature will be 24-26°C with 90% confidence".
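As a toy illustration of ensemble forecasting (not the tutorial's atmospheric model), the sketch below perturbs a single initial temperature and evolves every member with the same hypothetical relaxation rule `step`. It then reads a 90% interval off the 5th and 95th percentiles of the final states. The model, its noise scale, and the 24-step horizon are assumptions. The 800 members and the 0.5 perturbation mirror the `SimulationParams` defaults shown later.

```python
import random
import statistics


def step(temp_c: float, rng: random.Random) -> float:
    """Hypothetical toy dynamics: relax toward 25 °C plus small random forcing."""
    return temp_c + 0.1 * (25.0 - temp_c) + rng.gauss(0.0, 0.3)


def run_ensemble(initial_c: float, members: int = 800, steps: int = 24,
                 perturbation: float = 0.5, seed: int = 0) -> list[float]:
    """Final temperature of each ensemble member."""
    rng = random.Random(seed)
    finals = []
    for _ in range(members):
        # Each member starts from a slightly perturbed initial condition
        t = initial_c + rng.gauss(0.0, perturbation)
        for _ in range(steps):
            t = step(t, rng)
        finals.append(t)
    return finals


finals = run_ensemble(initial_c=24.0)
cuts = statistics.quantiles(finals, n=20)  # 19 cut points: 5%, 10%, ..., 95%
low, high = cuts[0], cuts[-1]
print(f"mean {statistics.fmean(finals):.1f} °C, 90% interval {low:.1f}-{high:.1f} °C")
```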

The pipeline has five stages:

1. **Data ingestion**: Pull satellite imagery from NOAA GOES, reanalysis data from ERA5, and surface observations from weather stations in parallel.
2. **Preprocessing**: Fuse the datasets, interpolate to a common grid, and run quality control using Dask for distributed computation.
3. **GPU simulation**: Run ensemble atmospheric physics on H200 GPUs. Each ensemble member evolves independently. PyTorch handles the tensor operations; `torch.compile` optimizes the kernels.
4. **Event detection**: Monitor for hurricanes (high wind + low pressure) and heatwaves during simulation. When extreme events are detected, the pipeline can adaptively refine the grid resolution.
5. **Real-time reporting**: Stream metrics to a live Flyte Reports dashboard showing convergence and detected events.

This workflow plays to several of Flyte's strengths:

- **Parallel data ingestion**: Three different data sources, three different APIs, all running concurrently. Flyte's async task execution handles this naturally.
- **Resource heterogeneity**: Data ingestion needs CPU and network. Preprocessing needs a Dask cluster. Simulation needs GPUs. Flyte provisions exactly what each stage needs.
- **Caching**: ERA5 data fetches can take minutes. Run the pipeline twice with the same date range, and Flyte skips the fetch entirely.
- **Adaptive workflows**: When a hurricane is detected, we can dynamically refine the simulation. Flyte makes this kind of conditional logic straightforward.

## Implementation

### Dependencies and container image

```
import asyncio
import gc
import io
import json
import os
import tempfile
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Literal

import flyte
import numpy as np
import pandas as pd
import xarray as xr
from flyte.io import File
from flyteplugins.dask import Dask, Scheduler, WorkerGroup
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/climate_modeling/simulation.py*

The key imports include `xarray` for multi-dimensional climate data, `flyteplugins.dask` for distributed preprocessing, and `flyte` for orchestration.

```
climate_image = (
    flyte.Image.from_debian_base(name="climate_modeling_h200")
    .with_apt_packages(
        "libnetcdf-dev",  # NetCDF for climate data
        "libhdf5-dev",  # HDF5 for large datasets
        "libeccodes-dev",  # GRIB format support (ECMWF's native format)
        "libudunits2-dev",  # Unit conversions
    )
    .with_pip_packages(
        "numpy==2.3.5",
        "pandas==2.3.3",
        "xarray==2025.11.0",
        "torch==2.9.1",
        "netCDF4==1.7.3",
        "s3fs==2025.10.0",
        "aiohttp==3.13.2",
        "ecmwf-datastores-client==0.4.1",
        "h5netcdf==1.7.3",
        "cfgrib==0.9.15.1",
        "pyarrow==22.0.0",
        "scipy==1.15.1",
        "flyteplugins-dask>=2.0.0b33",
        "nvidia-ml-py3==7.352.0",
    )
    .with_env_vars({"PYTORCH_CUDA_ALLOC_CONF": "max_split_size_mb:512"})
)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/climate_modeling/simulation.py*

Climate data comes in specialized formats such as NetCDF, HDF5, and GRIB. The container image includes libraries to work with all of them, along with PyTorch for GPU computation and the ECMWF client for accessing ERA5 data.

### Simulation parameters and data structures

```
@dataclass
class SimulationParams:
    grid_resolution_km: float = 10.0
    time_step_minutes: int = 10
    simulation_hours: int = 240
    physics_model: Literal["WRF", "MPAS", "CAM"] = "WRF"
    boundary_layer_scheme: str = "YSU"
    microphysics_scheme: str = "Thompson"
    radiation_scheme: str = "RRTMG"

    # Ensemble forecasting parameters
    ensemble_size: int = 800
    perturbation_magnitude: float = 0.5

    # Convergence criteria for adaptive refinement
    convergence_threshold: float = 0.1  # 10% of initial ensemble spread
    max_iterations: int = 3

@dataclass
class ClimateMetrics:
    timestamp: str
    iteration: int
    convergence_rate: float
    energy_conservation_error: float
    max_wind_speed_mps: float
    min_pressure_mb: float
    detected_phenomena: list[str]
    compute_time_seconds: float
    ensemble_spread: float

@dataclass
class SimulationSummary:
    total_iterations: int
    final_resolution_km: float
    avg_convergence_rate: float
    total_compute_time_seconds: float
    hurricanes_detected: int
    heatwaves_detected: int
    converged: bool
    region: str
    output_files: list[File]
    date_range: list[str]  # [start_date, end_date]
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/climate_modeling/simulation.py*

`SimulationParams` defines the core behavior of the simulation, including grid resolution, physics schemes, and ensemble size. The default configuration runs 800 ensemble members, which is sufficient to produce statistically meaningful uncertainty estimates.
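The defaults translate into concrete work per run. With `simulation_hours=240` and `time_step_minutes=10`, each member advances 240 × 60 / 10 = 1,440 timesteps. The sketch below works through that arithmetic and gives a rough state-size estimate under stated assumptions. It assumes a 2,000 km square domain, float32 storage, and five state variables (the temperature, pressure, moisture, and two wind components that the physics step indexes as 0-4). It ignores vertical levels and intermediate tensors.

```python
def timesteps(simulation_hours: int = 240, time_step_minutes: int = 10) -> int:
    """Number of integration steps each ensemble member takes."""
    return simulation_hours * 60 // time_step_minutes


def state_bytes(domain_km: float, grid_resolution_km: float, ensemble_size: int,
                n_vars: int = 5, bytes_per_value: int = 4) -> int:
    """Rough size of one 2D ensemble state (assumed square domain, float32 values)."""
    cells_per_side = int(domain_km / grid_resolution_km)
    return ensemble_size * n_vars * cells_per_side**2 * bytes_per_value


print(timesteps())  # 1440
for res in (10.0, 5.0):
    gib = state_bytes(2000.0, res, ensemble_size=800) / 2**30
    print(f"{res:>4} km grid: {gib:.1f} GiB")
```

Halving the grid spacing quadruples the horizontal cell count and, with it, the state size. Vertical levels and the temporaries each physics term allocates multiply the footprint further.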

> [!NOTE]
> Decreasing the grid spacing via `grid_resolution_km` (for example, from 10 km to 5 km) raises the resolution but also significantly increases memory usage. Halving the spacing roughly quadruples the number of horizontal grid points, and each point carries its own state and intermediate values. Even with 141 GB of H200 GPU memory, high-resolution or adaptively refined simulations may exceed available VRAM, especially when running large ensembles.
>
> To mitigate this, consider reducing the ensemble size, limiting the refined region, running fewer physics variables, or scaling the simulation across more GPUs so memory is distributed more evenly.

`ClimateMetrics` collects diagnostics at each iteration, such as convergence rate, energy conservation, and detected phenomena. These metrics are streamed to the real-time dashboard so you can monitor how the simulation evolves as it runs.

### Task environments

Different stages need different resources. Flyte's `TaskEnvironment` declares exactly what each task requires:

```
gpu_env = flyte.TaskEnvironment(
    name="climate_modeling_gpu",
    resources=flyte.Resources(
        cpu=5,
        memory="130Gi",
        gpu="H200:1",
    ),
    image=climate_image,
    cache="auto",
)

dask_env = flyte.TaskEnvironment(
    name="climate_modeling_dask",
    plugin_config=Dask(
        scheduler=Scheduler(resources=flyte.Resources(cpu=2, memory="6Gi")),
        workers=WorkerGroup(
            number_of_workers=2,
            resources=flyte.Resources(cpu=2, memory="12Gi"),
        ),
    ),
    image=climate_image,
    resources=flyte.Resources(cpu=2, memory="12Gi"),  # Head node
    cache="auto",
)

cpu_env = flyte.TaskEnvironment(
    name="climate_modeling_cpu",
    resources=flyte.Resources(cpu=8, memory="64Gi"),
    image=climate_image,
    cache="auto",
    secrets=[
        flyte.Secret(key="cds_api_key", as_env_var="ECMWF_DATASTORES_KEY"),
        flyte.Secret(key="cds_api_url", as_env_var="ECMWF_DATASTORES_URL"),
    ],
    depends_on=[gpu_env, dask_env],
)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/climate_modeling/simulation.py*

Here’s what each environment is responsible for:

- **`gpu_env`**: Runs the atmospheric simulations, one H200 GPU per task. The GPU's 141 GB of memory holds the ensemble members in VRAM during execution. The `memory="130Gi"` setting reserves host RAM for the container, not GPU memory.
- **`dask_env`**: Provides a distributed Dask cluster for preprocessing. A scheduler and multiple workers handle data fusion and transformation in parallel.
- **`cpu_env`**: Handles data ingestion and orchestration. This environment also includes the secrets required to access the ERA5 API.

The `depends_on` setting on `cpu_env` tells Flyte that the orchestration task calls tasks in `gpu_env` and `dask_env`. Those environments, including their images, are deployed along with `cpu_env`. Once they are ready, the orchestration task can launch the specialized simulation and preprocessing tasks.

### Data ingestion: multiple sources in parallel

Climate models need data from multiple sources. Each source has different formats, APIs, and failure modes. We handle them as separate Flyte tasks that run concurrently.

**Satellite imagery from NOAA GOES**

```
@cpu_env.task
async def ingest_satellite_data(region: str, date_range: list[str]) -> File:
    """Ingest GOES satellite imagery from NOAA's public S3 buckets."""
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/climate_modeling/simulation.py*

This task fetches cloud imagery and precipitable water products from NOAA's public S3 buckets. GOES-16 covers the Atlantic; GOES-17 covers the Pacific. The task selects the appropriate satellite based on region, fetches multiple days in parallel using `asyncio.gather`, and combines everything into a single xarray Dataset.
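The parallel per-day pattern looks roughly like the following sketch. `fetch_day` is a hypothetical stand-in that simulates I/O with `asyncio.sleep` instead of reading from S3. The region-to-bucket mapping assumes NOAA's public `noaa-goes16` and `noaa-goes17` buckets, following the description above. GOES-18 has since replaced GOES-17 as GOES-West, so recent Pacific dates may need `noaa-goes18`. The real task's selection logic and file handling may differ.

```python
import asyncio
from datetime import date, timedelta

# Assumed mapping, following the prose above
BUCKETS = {"atlantic": "noaa-goes16", "pacific": "noaa-goes17"}


def days_in_range(date_range: list[str]) -> list[date]:
    """Expand an inclusive [start, end] ISO date range into individual days."""
    start, end = (date.fromisoformat(d) for d in date_range)
    return [start + timedelta(days=i) for i in range((end - start).days + 1)]


async def fetch_day(bucket: str, day: date) -> str:
    """Hypothetical stand-in for downloading one day of imagery."""
    await asyncio.sleep(0.01)  # simulate network latency
    # Placeholder key (year/day-of-year), not the real GOES object layout
    return f"s3://{bucket}/{day:%Y/%j}"


async def fetch_all(region: str, date_range: list[str]) -> list[str]:
    bucket = BUCKETS[region]
    # One coroutine per day; gather runs them concurrently and keeps results in order
    return await asyncio.gather(*(fetch_day(bucket, d) for d in days_in_range(date_range)))


paths = asyncio.run(fetch_all("atlantic", ["2024-09-01", "2024-09-10"]))
print(len(paths), paths[0])  # 10 s3://noaa-goes16/2024/245
```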

**ERA5 reanalysis from Copernicus**

```
@cpu_env.task
async def ingest_reanalysis_data(region: str, date_range: list[str]) -> File:
    """Fetch ERA5 reanalysis from Copernicus Climate Data Store."""
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/climate_modeling/simulation.py*

ERA5 provides 3D atmospheric fields, such as temperature, wind, and humidity, at multiple pressure levels from the surface to the stratosphere. The ECMWF datastores client authenticates with the API key and URL that the Flyte secrets inject as `ECMWF_DATASTORES_KEY` and `ECMWF_DATASTORES_URL`. Each day is fetched in parallel, and the results are concatenated.

**Surface observations from weather stations**

```
@cpu_env.task
async def ingest_station_data(
    region: str, date_range: list[str], max_stations: int = 100
) -> File:
    """Fetch ground observations from NOAA's Integrated Surface Database."""
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/climate_modeling/simulation.py*

Ground truth comes from NOAA's Integrated Surface Database. The task filters stations by geographic bounds, fetches hourly observations, and returns a Parquet file for efficient downstream processing.
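Filtering stations by geographic bounds amounts to a bounding-box test on each station's coordinates, capped at `max_stations`. This is a minimal sketch with a hypothetical `Station` record and illustrative bounds. The tutorial's actual region bounds and ISD parsing are not shown here.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Station:
    station_id: str
    lat: float
    lon: float


def stations_in_bounds(stations: list[Station], lat_min: float, lat_max: float,
                       lon_min: float, lon_max: float, max_stations: int = 100) -> list[Station]:
    """Keep stations inside an inclusive lat/lon box, up to max_stations."""
    inside = [
        s for s in stations
        if lat_min <= s.lat <= lat_max and lon_min <= s.lon <= lon_max
    ]
    return inside[:max_stations]


# Illustrative box roughly over the western Atlantic (assumed, not the tutorial's bounds)
candidates = [
    Station("A", 25.8, -80.3),
    Station("B", 40.7, -74.0),
    Station("C", 51.5, -0.1),
]
print([s.station_id for s in stations_in_bounds(candidates, 10, 45, -100, -40)])  # ['A', 'B']
```

A box that crosses the antimeridian, which can happen for a Pacific region, would need the longitude test split into two ranges.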

All three tasks return Flyte `File` objects that hold references to data in blob storage. No data moves until a downstream task actually needs it.

### Preprocessing with Dask

The three data sources need to be combined into a unified atmospheric state. This means:
- Interpolating to a common grid
- Handling missing values
- Merging variables from different sources
- Quality control

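Dask is a good fit here. It evaluates lazily over chunked arrays, so the cluster can process the datasets chunk by chunk in parallel without loading everything into memory at once: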

```python
@dask_env.task
async def preprocess_atmospheric_data(
    satellite_data: File,
    reanalysis_data: File,
    station_data: File,
    target_resolution_km: float,
) -> File:
```

This task connects to the Dask cluster provisioned by Flyte, loads the datasets with appropriate chunking, merges satellite and reanalysis grids, fills in missing values, and persists the result. Flyte caches the output, so preprocessing only runs when the inputs change.
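Interpolating onto a common grid and filling gaps form the core of this step. The sketch below shows the idea in one dimension with NumPy, using a hypothetical `regrid_and_fill` helper. It is an illustration, not the task's code. For multi-dimensional labeled data, xarray's `interp` and `interpolate_na` methods provide the corresponding operations.

```python
import numpy as np


def regrid_and_fill(src_coords: np.ndarray, values: np.ndarray,
                    target_coords: np.ndarray) -> np.ndarray:
    """Linearly interpolate values onto target_coords, skipping missing (NaN) samples."""
    valid = ~np.isnan(values)
    # np.interp bridges gaps from neighbouring valid samples and clamps at the edges
    return np.interp(target_coords, src_coords[valid], values[valid])


# Coarse source grid (every 25 km) with one missing observation
src = np.array([0.0, 25.0, 50.0, 75.0, 100.0])
temps = np.array([20.0, np.nan, 22.0, 23.0, 24.0])
target = np.arange(0.0, 101.0, 10.0)  # common 10 km grid
print(regrid_and_fill(src, temps, target))
```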

### GPU-accelerated atmospheric simulation

Now the core: running atmospheric physics on the GPU. Each ensemble member is an independent forecast with slightly perturbed initial conditions.

```
@gpu_env.task
async def run_atmospheric_simulation(
    input_data: File,
    params: SimulationParams,
    partition_id: int = 0,
    ensemble_start: int | None = None,
    ensemble_end: int | None = None,
) -> tuple[File, ClimateMetrics]:
    """Run GPU-accelerated atmospheric simulation with ensemble forecasting."""
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/climate_modeling/simulation.py*

The task accepts a subset of ensemble members (`ensemble_start` to `ensemble_end`). This enables distributing 800 members across multiple GPUs.

The physics step is the computational kernel. It runs advection (wind transport), pressure gradients, Coriolis forces, turbulent diffusion, and moisture condensation:

```
    @torch.compile(mode="reduce-overhead")
    def physics_step(state_tensor, dt_val, dx_val):
        """Compiled atmospheric physics - 3-4x faster with torch.compile."""
        # Advection: transport by wind
        temp_grad_x = torch.roll(state_tensor[:, 0], -1, dims=2) - torch.roll(
            state_tensor[:, 0], 1, dims=2
        )
        temp_grad_y = torch.roll(state_tensor[:, 0], -1, dims=3) - torch.roll(
            state_tensor[:, 0], 1, dims=3
        )
        advection = -(
            state_tensor[:, 3] * temp_grad_x + state_tensor[:, 4] * temp_grad_y
        ) / (2 * dx_val)
        state_tensor[:, 0] = state_tensor[:, 0] + advection * dt_val

        # Pressure gradient with Coriolis
        pressure_grad_x = (
            torch.roll(state_tensor[:, 1], -1, dims=2)
            - torch.roll(state_tensor[:, 1], 1, dims=2)
        ) / (2 * dx_val)
        pressure_grad_y = (
            torch.roll(state_tensor[:, 1], -1, dims=3)
            - torch.roll(state_tensor[:, 1], 1, dims=3)
        ) / (2 * dx_val)

        coriolis_param = 1e-4  # ~45°N latitude
        coriolis_u = coriolis_param * state_tensor[:, 4]
        coriolis_v = -coriolis_param * state_tensor[:, 3]

        state_tensor[:, 3] = (
            state_tensor[:, 3] - pressure_grad_x * dt_val * 0.01 + coriolis_u * dt_val
        )
        state_tensor[:, 4] = (
            state_tensor[:, 4] - pressure_grad_y * dt_val * 0.01 + coriolis_v * dt_val
        )

        # Turbulent diffusion
        diffusion_coeff = 10.0
        laplacian_temp = (
            torch.roll(state_tensor[:, 0], 1, dims=2)
            + torch.roll(state_tensor[:, 0], -1, dims=2)
            + torch.roll(state_tensor[:, 0], 1, dims=3)
            + torch.roll(state_tensor[:, 0], -1, dims=3)
            - 4 * state_tensor[:, 0]
        ) / (dx_val * dx_val)
        state_tensor[:, 0] = (
            state_tensor[:, 0] + diffusion_coeff * laplacian_temp * dt_val
        )

        # Moisture condensation
        sat_vapor_pressure = 611.2 * torch.exp(
            17.67 * state_tensor[:, 0] / (state_tensor[:, 0] + 243.5)
        )
        condensation = torch.clamp(
            state_tensor[:, 2] - sat_vapor_pressure * 0.001, min=0
        )
        state_tensor[:, 2] = state_tensor[:, 2] - condensation * 0.1
        state_tensor[:, 0] = state_tensor[:, 0] + condensation * 2.5e6 / 1005 * dt_val

        return state_tensor
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/climate_modeling/simulation.py*

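`@torch.compile(mode="reduce-overhead")` compiles this function with TorchInductor, which fuses its many element-wise operations into fewer GPU kernels. The `reduce-overhead` mode also uses CUDA graphs to cut kernel-launch overhead. Combined with mixed precision via `torch.cuda.amp.autocast`, this runs 3-4x faster than eager PyTorch in the tutorial's setup. On recent PyTorch releases, `torch.cuda.amp.autocast` is deprecated in favor of `torch.amp.autocast("cuda")`.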
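The `torch.roll` calls implement centered finite differences on a grid with periodic (wrap-around) boundaries. The same stencils are easy to check in NumPy, where `np.roll` shifts elements the same way `torch.roll` does. This standalone sketch applies the five-point Laplacian from the turbulent-diffusion term to a 2D field. It then confirms a property of periodic diffusion: diffusion smooths the field without changing its total. The grid size, field, and 10 km / 10-minute spacing are illustrative assumptions.

```python
import numpy as np


def centered_gradient(field: np.ndarray, axis: int, dx: float) -> np.ndarray:
    """Centered difference (f[i+1] - f[i-1]) / (2 dx) with periodic boundaries."""
    return (np.roll(field, -1, axis=axis) - np.roll(field, 1, axis=axis)) / (2 * dx)


def laplacian(field: np.ndarray, dx: float) -> np.ndarray:
    """Five-point Laplacian with periodic boundaries, as in the diffusion term."""
    return (
        np.roll(field, 1, axis=0) + np.roll(field, -1, axis=0)
        + np.roll(field, 1, axis=1) + np.roll(field, -1, axis=1)
        - 4 * field
    ) / (dx * dx)


rng = np.random.default_rng(0)
temp = 15.0 + rng.normal(0.0, 2.0, size=(32, 32))  # illustrative temperature field (°C)
dx, dt, diffusion_coeff = 10_000.0, 600.0, 10.0     # 10 km grid, 10-minute step

smoothed = temp.copy()
for _ in range(100):
    # Explicit Euler diffusion step; stable because D*dt/dx**2 is far below 0.25
    smoothed = smoothed + diffusion_coeff * laplacian(smoothed, dx) * dt

print(np.isclose(temp.sum(), smoothed.sum()), smoothed.std() < temp.std())
```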

Every 10 timesteps, the simulation checks for extreme events:
- **Hurricanes**: Wind speed > 33 m/s with low pressure
- **Heatwaves**: Temperature anomalies exceeding thresholds

Detected phenomena are recorded in the `detected_phenomena` field of `ClimateMetrics`, which feeds the live dashboard.
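The check described above can be sketched as follows. The sketch assumes wind components `u` and `v` in m/s, pressure in millibars (matching `ClimateMetrics.min_pressure_mb`), and temperature anomalies in kelvin on a shared grid. Only the 33 m/s wind threshold comes from the text. The 980 mb pressure cutoff and the 5 K heatwave anomaly are hypothetical placeholders. The name `detect_phenomena` is illustrative, not the tutorial's.

```python
import numpy as np

HURRICANE_WIND_MPS = 33.0      # from the tutorial text
HURRICANE_PRESSURE_MB = 980.0  # hypothetical cutoff for "low pressure"
HEATWAVE_ANOMALY_K = 5.0       # hypothetical anomaly threshold


def detect_phenomena(u: np.ndarray, v: np.ndarray, pressure_mb: np.ndarray,
                     temp_anomaly_k: np.ndarray) -> list[str]:
    """Flag hurricane- and heatwave-like conditions anywhere on the grid."""
    wind_speed = np.hypot(u, v)  # sqrt(u**2 + v**2)
    found = []
    # Require strong wind and low pressure at the same grid point
    if np.any((wind_speed > HURRICANE_WIND_MPS) & (pressure_mb < HURRICANE_PRESSURE_MB)):
        found.append("hurricane")
    if np.any(temp_anomaly_k > HEATWAVE_ANOMALY_K):
        found.append("heatwave")
    return found


u = np.array([[5.0, 30.0], [2.0, 1.0]])
v = np.array([[5.0, 20.0], [2.0, 1.0]])
p = np.array([[1012.0, 965.0], [1010.0, 1011.0]])
anom = np.zeros((2, 2))
print(detect_phenomena(u, v, p, anom))  # ['hurricane']
```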

### Distributing across multiple GPUs

Running all 800 ensemble members on a single GPU would be slow and memory-hungry, so we distribute them across several GPUs:

```
@cpu_env.task
async def run_distributed_simulation_ensemble(
    preprocessed_data: File, params: SimulationParams, n_gpus: int
) -> tuple[list[File], list[ClimateMetrics]]:
    total_members = params.ensemble_size
    members_per_gpu = total_members // n_gpus

    # Distribute ensemble members across GPUs
    tasks = []
    for gpu_id in range(n_gpus):
        # Calculate ensemble range for this GPU
        ensemble_start = gpu_id * members_per_gpu
        # Last GPU gets any remainder members
        if gpu_id == n_gpus - 1:
            ensemble_end = total_members
        else:
            ensemble_end = ensemble_start + members_per_gpu

        # Launch GPU task with ensemble subset
        gpu_task = run_atmospheric_simulation(
            preprocessed_data,
            params,
            gpu_id,
            ensemble_start=ensemble_start,
            ensemble_end=ensemble_end,
        )
        tasks.append(gpu_task)

    # Execute all GPUs in parallel
    results = await asyncio.gather(*tasks)

    output_files = [r[0] for r in results]
    metrics = [r[1] for r in results]

    return output_files, metrics
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/climate_modeling/simulation.py*

The task splits the ensemble members into contiguous blocks, one per GPU, with the last GPU also taking any remainder. It launches the simulation runs in parallel using `asyncio.gather` and then collects the output files and metrics. With five GPUs and 800 members, each GPU runs 160 ensemble members. Flyte takes care of scheduling, so GPU tasks start automatically as soon as resources become available.
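The range arithmetic can be checked on its own. The sketch below reproduces only the index math from the loop above, with no Flyte or GPUs involved; `split_ensemble` is a hypothetical helper name, and each range is a half-open `(start, end)` pair.

```python
def split_ensemble(total_members: int, n_gpus: int) -> list[tuple[int, int]]:
    """Return one (start, end) member range per GPU; `end` is exclusive.

    Mirrors the loop in run_distributed_simulation_ensemble: every GPU gets
    total_members // n_gpus members, and the last GPU also takes the remainder.
    """
    if n_gpus <= 0:
        raise ValueError("n_gpus must be positive")
    members_per_gpu = total_members // n_gpus
    ranges = []
    for gpu_id in range(n_gpus):
        start = gpu_id * members_per_gpu
        # The last GPU runs to the end so no member is dropped.
        end = total_members if gpu_id == n_gpus - 1 else start + members_per_gpu
        ranges.append((start, end))
    return ranges


print(split_ensemble(800, 5))
# [(0, 160), (160, 320), (320, 480), (480, 640), (640, 800)]
print(split_ensemble(803, 5))
# [(0, 160), (160, 320), (320, 480), (480, 640), (640, 803)]
```

With 803 members the last GPU picks up the 3 leftover members. If the remainder can be large relative to the block size, spreading it one member per GPU (for example with `divmod`) balances the load more evenly; that variant is not what the tutorial does.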

### The main workflow

Everything comes together in the orchestration task:

```
@cpu_env.task(report=True)
async def adaptive_climate_modeling_workflow(
    region: str = "atlantic",
    date_range: list[str] = ["2024-09-01", "2024-09-10"],
    current_params: SimulationParams = SimulationParams(),
    enable_multi_gpu: bool = True,
    n_gpus: int = 5,
) -> SimulationSummary:
    """Orchestrates multi-source ingestion, GPU simulation, and adaptive refinement."""
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/climate_modeling/simulation.py*

`report=True` enables Flyte Reports for live monitoring.

```
    # Parallel data ingestion from three sources
    with flyte.group("data-ingestion"):
        satellite_task = ingest_satellite_data(region, date_range)
        reanalysis_task = ingest_reanalysis_data(region, date_range)
        station_task = ingest_station_data(region, date_range)

        satellite_data, reanalysis_data, station_data = await asyncio.gather(
            satellite_task,
            reanalysis_task,
            station_task,
        )
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/climate_modeling/simulation.py*

`flyte.group("data-ingestion")` visually groups the ingestion tasks in the Flyte UI. Inside the group, each task call returns a coroutine without running it yet; `asyncio.gather` then runs all three ingestion tasks concurrently and waits for them to finish before preprocessing begins.
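The same concurrency pattern can be seen in plain asyncio, without Flyte. In this sketch `ingest` is a hypothetical stand-in that sleeps to simulate I/O; in the tutorial, each coroutine would instead be a remote ingestion task.

```python
import asyncio
import time


async def ingest(source: str, delay_s: float) -> str:
    # Stand-in for an ingestion task; sleeping simulates I/O latency.
    await asyncio.sleep(delay_s)
    return f"{source}-data"


async def ingest_all() -> tuple[list[str], float]:
    start = time.perf_counter()
    # Calling an async function only creates a coroutine; no work starts yet.
    satellite = ingest("satellite", 0.1)
    reanalysis = ingest("reanalysis", 0.1)
    station = ingest("station", 0.1)
    # gather runs all three concurrently and returns results in argument order.
    results = await asyncio.gather(satellite, reanalysis, station)
    return list(results), time.perf_counter() - start


results, elapsed = asyncio.run(ingest_all())
print(results)  # ['satellite-data', 'reanalysis-data', 'station-data']
print(f"elapsed: {elapsed:.2f}s")
```

The elapsed time should be close to the longest single delay (about 0.1 s) rather than the 0.3 s that three sequential awaits would take.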

The workflow then enters an iterative loop:
1. Run GPU simulation (single or multi-GPU)
2. Check convergence by comparing forecasts across iterations
3. Detect extreme events
4. If a hurricane is detected and we haven't refined yet, double the grid resolution
5. Stream metrics to the live dashboard
6. Repeat until converged or max iterations reached

Adaptive mesh refinement is the key feature here. When the simulation detects a hurricane forming, it automatically increases resolution to capture the fine-scale dynamics. This is expensive, so we limit it to one refinement per run.
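The control flow of that loop can be sketched in plain Python. This is an illustrative reduction, not the tutorial's code: the GPU run is abstracted as a hypothetical `simulate(grid_spacing_km)` callable that returns a single scalar forecast plus a hurricane flag, the 1% relative-change convergence test is an assumed criterion, and metric streaming is omitted.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class LoopResult:
    iterations: int
    converged: bool
    refined: bool
    grid_spacing_km: float


def adaptive_loop(
    simulate: Callable[[float], tuple[float, bool]],
    grid_spacing_km: float = 10.0,
    max_iterations: int = 5,
    tolerance: float = 0.01,
) -> LoopResult:
    """Hypothetical driver; simulate(spacing) returns (forecast, hurricane_detected)."""
    previous: Optional[float] = None
    refined = False
    for iteration in range(1, max_iterations + 1):
        forecast, hurricane = simulate(grid_spacing_km)
        # Converged: the forecast moved by less than `tolerance` (relative) since last iteration.
        if previous is not None and abs(forecast - previous) <= tolerance * abs(previous):
            return LoopResult(iteration, True, refined, grid_spacing_km)
        # Refine at most once per run: halving the spacing doubles the resolution.
        if hurricane and not refined:
            grid_spacing_km /= 2
            refined = True
        previous = forecast
    return LoopResult(max_iterations, False, refined, grid_spacing_km)


# Usage with a scripted stand-in for the simulation:
forecasts = iter([100.0, 90.0, 89.5])
result = adaptive_loop(lambda spacing: (next(forecasts), spacing > 5.0))
print(result)
# LoopResult(iterations=3, converged=True, refined=True, grid_spacing_km=5.0)
```

The `refined` flag is what enforces the one-refinement limit: even if a hurricane is flagged on every iteration, the grid spacing is halved only once.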

### Running the pipeline

```
if __name__ == "__main__":
    flyte.init_from_config()
    run_multi_gpu = flyte.run(adaptive_climate_modeling_workflow)

    print(f"Run URL: {run_multi_gpu.url}")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/climate_modeling/simulation.py*

Before running, set up ERA5 API credentials:

```bash
flyte create secret cds_api_key <YOUR_CDS_API_KEY>
flyte create secret cds_api_url https://cds.climate.copernicus.eu/api
```

Then launch:

```bash
flyte create config --endpoint <FLYTE_OR_UNION_ENDPOINT> --project <PROJECT_NAME> --domain <DOMAIN_NAME> --builder remote
uv run simulation.py
```

The default configuration covers the Atlantic region from September 1 to 10, 2024, near the peak of the Atlantic hurricane season.

## Key concepts

### Ensemble forecasting

Weather prediction is inherently uncertain. Small errors in the initial conditions grow over time due to chaotic dynamics, which means a single forecast can only ever be one possible outcome.

Ensemble forecasting addresses this uncertainty by:
- Perturbing the initial conditions within known observational error bounds
- Running many independent forecasts
- Computing the ensemble mean as the most likely outcome and the ensemble spread as a measure of uncertainty

### Adaptive mesh refinement

When a hurricane begins to form, coarse spatial grids are not sufficient to resolve critical features like eyewall dynamics. Adaptive mesh refinement allows the simulation to focus compute where it matters most by:
- Increasing grid resolution, for example from 10 km to 5 km
- Reducing the timestep to maintain numerical stability
- Refining only the regions of interest instead of the entire domain

This approach is computationally expensive, but it is essential for producing accurate intensity forecasts.
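The timestep reduction follows from explicit-scheme stability limits. The sketch below assumes an explicit scheme with a Courant number of 0.5, 50 m/s peak winds, and a diffusivity of 1000 m²/s; these are illustrative values, not the tutorial's settings. The diffusion limit corresponds to an explicit five-point Laplacian like the one in the physics step.

```python
def advective_timestep(dx_m: float, max_speed_ms: float, courant: float = 0.5) -> float:
    """CFL limit for explicit advection: dt <= C * dx / max|u| (seconds)."""
    if dx_m <= 0 or max_speed_ms <= 0:
        raise ValueError("grid spacing and wind speed must be positive")
    return courant * dx_m / max_speed_ms


def diffusion_timestep(dx_m: float, diffusivity_m2s: float) -> float:
    """Stability limit for explicit 2D diffusion with a five-point Laplacian: dt <= dx^2 / (4 D)."""
    return dx_m ** 2 / (4 * diffusivity_m2s)


def stable_timestep(dx_m: float, max_speed_ms: float, diffusivity_m2s: float) -> float:
    # The step must satisfy both limits, so take the stricter one.
    return min(advective_timestep(dx_m, max_speed_ms), diffusion_timestep(dx_m, diffusivity_m2s))


# Refining from a 10 km to a 5 km grid with 50 m/s winds (illustrative numbers):
coarse = stable_timestep(10_000.0, 50.0, diffusivity_m2s=1_000.0)
fine = stable_timestep(5_000.0, 50.0, diffusivity_m2s=1_000.0)
print(coarse, fine)  # 100.0 50.0
```

Halving the grid spacing halves the advective limit and quarters the diffusive one, so a refined run needs more steps per simulated hour as well as more grid cells, which is where the extra cost comes from.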

### Real-time event detection

Rather than analyzing results after a simulation completes, this pipeline detects significant events as the simulation runs.

The system monitors for conditions such as:
- **Hurricanes**: Wind speeds exceeding 33 m/s (Category 1 threshold) combined with central pressure below 980 mb
- **Heatwaves**: Sustained temperature anomalies over a defined period

Detecting these events in real time enables adaptive responses, such as refining the simulation or triggering alerts, and supports earlier warnings for extreme weather.
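A minimal sketch of both checks, assuming the simulation exposes per-check summaries (maximum wind speed, minimum central pressure, and a temperature anomaly). The hurricane thresholds come from the text above; the 5 K anomaly threshold and three-check window for heatwaves are assumptions, not the tutorial's values.

```python
from collections import deque

HURRICANE_WIND_MS = 33.0       # Category 1 wind threshold, from the text
HURRICANE_PRESSURE_MB = 980.0  # central-pressure threshold, from the text


def is_hurricane(max_wind_ms: float, min_pressure_mb: float) -> bool:
    """Both conditions must hold: strong winds and a deep low."""
    return max_wind_ms > HURRICANE_WIND_MS and min_pressure_mb < HURRICANE_PRESSURE_MB


class HeatwaveDetector:
    """Flags a heatwave once the anomaly exceeds `threshold_k` for `window` consecutive checks."""

    def __init__(self, threshold_k: float = 5.0, window: int = 3) -> None:
        self.threshold_k = threshold_k
        self.recent = deque(maxlen=window)  # rolling record of threshold exceedances

    def update(self, anomaly_k: float) -> bool:
        self.recent.append(anomaly_k > self.threshold_k)
        return len(self.recent) == self.recent.maxlen and all(self.recent)


# Per-check summaries: (max wind m/s, min pressure mb, temperature anomaly K); synthetic values.
checks = [(30.0, 995.0, 2.0), (36.0, 975.0, 6.0), (40.0, 965.0, 6.5), (42.0, 960.0, 7.0)]
detector = HeatwaveDetector()
events = []
for i, (wind, pressure, anomaly) in enumerate(checks):
    if is_hurricane(wind, pressure):
        events.append((i, "hurricane"))
    if detector.update(anomaly):
        events.append((i, "heatwave"))
print(events)
# [(1, 'hurricane'), (2, 'hurricane'), (3, 'hurricane'), (3, 'heatwave')]
```

The bounded `deque` is what turns a single hot reading into a "sustained" condition: the heatwave flag only fires once every entry in the window exceeds the threshold.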

## Where to go next

This example is intentionally scoped to keep the ideas clear, but there are several natural ways to extend it for more realistic workloads.

To model different ocean basins, change the `region` parameter to values like `"pacific"` or `"indian"`. The ingestion tasks automatically adjust to pull the appropriate satellite coverage for each region.

To run longer forecasts, increase `simulation_hours` in `SimulationParams`. The default of 240 hours, or 10 days, is typical for medium-range forecasting, but you can run longer simulations if you have the compute budget.

Finally, the physics step here is deliberately simplified. Production systems usually incorporate additional components such as radiation schemes, boundary layer parameterizations, and land surface models. These can be added incrementally as separate steps without changing the overall structure of the pipeline.

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/geospatial/satellite_image_classification ===

# Satellite image classification

![Satellite Image](https://www.union.ai/docs/v2/union/_static/images/tutorials/satellite_image_classification/satellite_image.png)

## Background

Remote sensing has transformed how we monitor our planet. From tracking deforestation to detecting urban sprawl, satellite imagery provides a bird's-eye view of land use change at global scale.

But training a model that can reliably classify that imagery — across 10 distinct land-use categories, at production quality — requires more than just a good model. It requires a pipeline that handles data, compute, caching, experiment tracking, and reporting as first-class concerns.

This tutorial walks through a complete satellite image classification pipeline built on Union.ai, using EfficientNet-B0, a two-phase training strategy, and Weights & Biases for experiment tracking.

> [!NOTE]
> Full code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/satellite_image_classification).

## Dataset

[EuroSAT](https://github.com/phelber/EuroSAT) is a benchmark dataset of 27,000 labeled satellite images from the Sentinel-2 mission. Each image is 64×64 pixels, and the images are split across 10 land-use classes: Annual Crop, Forest, Herbaceous Vegetation, Highway, Industrial, Pasture, Permanent Crop, Residential, River, and Sea/Lake.

It's a well-structured dataset (roughly balanced across classes and clearly labeled), which makes it ideal for demonstrating a production-grade training pipeline without the overhead of massive data infrastructure.

## Model

We use EfficientNet-B0 from `timm` (the PyTorch Image Models library), pretrained on ImageNet. EfficientNet was designed to scale depth, width, and resolution jointly using a compound coefficient, giving strong accuracy with a relatively small parameter count (~5.3M for B0). The ImageNet pretraining means the backbone already recognizes edges, textures, and shapes, features that transfer well to satellite imagery.
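For reference, the compound scaling rule from the original EfficientNet paper (Tan and Le) ties depth $d$, width $w$, and input resolution $r$ to a single coefficient $\phi$:

$$
d = \alpha^{\phi}, \qquad w = \beta^{\phi}, \qquad r = \gamma^{\phi}, \qquad \text{subject to } \alpha \cdot \beta^{2} \cdot \gamma^{2} \approx 2,\ \ \alpha, \beta, \gamma \ge 1
$$

B0 is the baseline network ($\phi = 0$). The paper's grid search found $\alpha = 1.2$, $\beta = 1.1$, $\gamma = 1.15$; because FLOPs scale with $d \cdot w^{2} \cdot r^{2}$, total FLOPs grow by roughly $2^{\phi}$ as the network is scaled up.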

## Two-phase training

Naively fine-tuning a pretrained model, updating all weights from the first step, often leads to catastrophic forgetting: large gradients flowing back from the randomly initialized head disrupt the learned representations before the new task-specific head has had a chance to stabilize.

Instead, we use a two-phase approach:

**Phase 1: Feature Extraction (frozen backbone).** The EfficientNet backbone is frozen. Only the classification head is trained, at a relatively high learning rate (2e-3). This gives the head 7 epochs to learn to map ImageNet features to EuroSAT categories, without disturbing the pretrained weights.

**Phase 2: Fine-tuning (unfrozen backbone).** The backbone is unfrozen and added to the optimizer with a 10× lower learning rate than the head (`phase2_lr` × 0.1). A fresh cosine annealing schedule is initialized over the remaining steps, so the learning rate doesn't arrive near-zero from Phase 1's schedule before Phase 2 even begins. This lets the backbone adapt to satellite-specific features while preserving the general representations it learned on ImageNet.

The transition happens automatically inside a `PhaseChangeCallback`:

```
    class PhaseChangeCallback(L.Callback):
        def __init__(self, phase1_epochs: int, phase2_lr: float):
            super().__init__()
            self.phase1_epochs = phase1_epochs
            self.phase2_lr = phase2_lr
            self.phase_changed = False

        def on_train_epoch_end(self, trainer, pl_module):
            if not self.phase_changed and (trainer.current_epoch + 1) == self.phase1_epochs:
                print("\n" + "=" * 80)
                print("TRANSITIONING TO PHASE 2: UNFREEZING BACKBONE AND ADJUSTING LR")
                print("=" * 80 + "\n")

                pl_module.model.unfreeze_backbone()

                for param_group in trainer.optimizers[0].param_groups:
                    param_group["lr"] = self.phase2_lr

                # Add backbone params to optimizer with 10x lower LR.
                # Backbone was excluded at init because it was frozen.
                backbone_lr = self.phase2_lr * 0.1
                backbone_decay, backbone_no_decay = [], []
                for param in pl_module.model.backbone.parameters():
                    if param.ndim >= 2:
                        backbone_decay.append(param)
                    else:
                        backbone_no_decay.append(param)
                optimizer = trainer.optimizers[0]
                optimizer.add_param_group({"params": backbone_decay, "lr": backbone_lr, "weight_decay": pl_module.weight_decay})
                optimizer.add_param_group({"params": backbone_no_decay, "lr": backbone_lr, "weight_decay": 0.0})

                # Fresh cosine schedule over remaining Phase 2 steps to avoid
                # the Phase 1 schedule arriving near-zero before Phase 2 begins.
                steps_remaining = trainer.estimated_stepping_batches - trainer.global_step
                new_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                    trainer.optimizers[0],
                    T_max=max(1, steps_remaining),
                    eta_min=1e-6,
                )
                for lr_scheduler_config in trainer.lr_scheduler_configs:
                    lr_scheduler_config.scheduler = new_scheduler

                print(f"Phase 2 started: lr={self.phase2_lr}")
                print(f"Total parameters: {get_model_size(pl_module.model):,}")
                print(f"Trainable parameters: {get_trainable_params(pl_module.model):,}")
                self.phase_changed = True
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/satellite_image_classification/training.py*

This two-phase strategy consistently reaches >95% validation accuracy on EuroSAT within 17 total epochs.
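To see why the callback builds a fresh scheduler, compare the learning rate at the end of Phase 1 with the start of Phase 2. The sketch assumes Phase 1's cosine schedule spans only the Phase 1 steps, 100 steps per epoch, and a `phase2_lr` of 1e-3; these are illustrative, since only the 2e-3 Phase 1 rate and `eta_min=1e-6` appear in the tutorial. The backbone's separate 0.1× parameter groups are omitted.

```python
import math


def cosine_lr(step: int, total_steps: int, lr_max: float, lr_min: float = 1e-6) -> float:
    """Closed-form cosine annealing, the curve CosineAnnealingLR follows up to T_max."""
    t = min(step, total_steps)
    return lr_min + (lr_max - lr_min) * (1 + math.cos(math.pi * t / total_steps)) / 2


def two_phase_schedule(phase1_steps, phase2_steps, phase1_lr, phase2_lr, lr_min=1e-6):
    """Head LR per step: one cosine curve for Phase 1, then a fresh curve for Phase 2."""
    phase1 = [cosine_lr(s, phase1_steps, phase1_lr, lr_min) for s in range(phase1_steps)]
    phase2 = [cosine_lr(s, phase2_steps, phase2_lr, lr_min) for s in range(phase2_steps)]
    return phase1 + phase2


# Hypothetical: 100 steps/epoch, 7 Phase 1 epochs, 10 Phase 2 epochs, phase2_lr = 1e-3.
lrs = two_phase_schedule(700, 1000, phase1_lr=2e-3, phase2_lr=1e-3)
print(f"end of Phase 1: {lrs[699]:.2e}, start of Phase 2: {lrs[700]:.2e}")
```

Without the restart, Phase 2 would inherit a learning rate near `eta_min`, leaving the newly unfrozen backbone almost no room to adapt; the fresh schedule restarts at `phase2_lr` and decays smoothly over the remaining steps.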

## Pipeline

Training a model is only part of the story. The real challenge is building a system that is reproducible, cost-efficient, and easy to iterate on. That's where Union's `TaskEnvironment` model shines: each stage of the pipeline runs in the right compute environment, and results are cached so you never pay for work you've already done.

The pipeline has four components, each with its own environment defined in `config.py`.

### Task 1: Data download (`dataset_env`)

```
@dataset_env.task
async def load_dataset() -> Dir:
    """
    Download raw EuroSAT JPEG files and cache as flyte.io.Dir.
    Runs once — result is reused on subsequent pipeline runs (cache="auto").
    """
    return await download_eurosat()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/satellite_image_classification/run.py*

This task downloads the raw EuroSAT JPEG files via torchvision and packages them as a `flyte.io.Dir`. It runs on a lightweight CPU container (2 cores, 2 GB RAM); no GPU is needed. With `cache="auto"`, the result is stored and reused on every subsequent run, so you pay for the download exactly once.

No preprocessing happens here. Raw images are passed directly to training so that all transforms (resize, normalization, and augmentation) happen per batch with the full training context, feeding the model properly prepared 224×224 inputs derived from the original pixels.
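The tutorial's transforms themselves aren't shown in this section. As a dependency-free illustration of the arithmetic they perform, the sketch below upsamples a 64×64 tile to 224×224 and standardizes each channel. Nearest-neighbour resizing and the ImageNet statistics are assumptions; real pipelines typically use bilinear or bicubic resizing on tensors via torchvision or timm.

```python
# Assumed normalization statistics: the standard ImageNet mean/std, which
# ImageNet-pretrained models commonly expect.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

Pixel = tuple[int, int, int]


def resize_nearest(image: list[list[Pixel]], out_size: int) -> list[list[Pixel]]:
    """Nearest-neighbour resize of a square RGB image stored as rows of pixels."""
    in_size = len(image)
    return [
        [image[(r * in_size) // out_size][(c * in_size) // out_size] for c in range(out_size)]
        for r in range(out_size)
    ]


def normalize_pixel(rgb: Pixel) -> tuple[float, ...]:
    """Scale 8-bit values to [0, 1], then standardize each channel."""
    return tuple(
        (value / 255.0 - mean) / std
        for value, mean, std in zip(rgb, IMAGENET_MEAN, IMAGENET_STD)
    )


# A synthetic 64x64 tile (not real EuroSAT data) resized to the model's 224x224 input.
tile = [[(r * 4, c * 4, 128) for c in range(64)] for r in range(64)]
model_input = [[normalize_pixel(p) for p in row] for row in resize_nearest(tile, 224)]
```

Because each batch starts from the original pixels, the resize and normalization are applied exactly once per sample, with no intermediate compressed copy.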

### Task 2: GPU training (`training_env`)

```
@wandb_init
@training_env.task
async def train_model(dataset_dir: Dir, config: TrainingConfig) -> Dir:
    """
    Download the raw dataset Dir, run two-phase training,
    and return training metrics as a Dir for the report task.
    """
    from pathlib import Path

    local_dir = Path("/tmp/eurosat_local")
    local_dir.mkdir(parents=True, exist_ok=True)
    await dataset_dir.download(local_path=str(local_dir))

    result = train_satellite_classifier(config=config, dataset_path=str(local_dir))

    output_dir = Path("/tmp/training_results")
    output_dir.mkdir(parents=True, exist_ok=True)
    (output_dir / "metrics.json").write_text(json.dumps(result["metrics"]))

    return await Dir.from_local(str(output_dir))
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/satellite_image_classification/run.py*

This task runs on a T4 GPU with 32 GB RAM. It receives the dataset `Dir` from Task 1, downloads it locally, then runs the two-phase training loop using PyTorch Lightning.

Two things worth noting:

- With `cache="auto"`, training results are cached based on the input data and config. If you rerun the pipeline with the same dataset and hyperparameters, Union skips training entirely and returns the cached metrics. This makes hyperparameter search much cheaper: only configurations you haven't tried before actually execute.

- `@wandb_init`: the `flyteplugins-wandb` integration initializes a W&B run automatically and makes it available via `get_wandb_run()`. The training code hands that run to Lightning's `WandbLogger` (shown below), so every training run logs metrics, learning-rate curves, and t-SNE visualizations of the learned feature space to your W&B project.

```
    wandb_logger = WandbLogger(experiment=get_wandb_run(), log_model=False)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/satellite_image_classification/training.py*
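The caching behaviour described in the first point above can be illustrated with a toy, input-keyed memoizer. This is not how Union computes cache keys (the platform derives them itself and may include more than the inputs, such as a cache version); `InputKeyedCache` and `fake_train` are hypothetical names used only to show why a repeated dataset and config skip execution.

```python
import hashlib
import json
from typing import Any, Callable


class InputKeyedCache:
    """Toy memoizer: a result is reused when the task name and inputs match exactly."""

    def __init__(self) -> None:
        self.store: dict[str, Any] = {}
        self.executions = 0

    @staticmethod
    def key(task_name: str, inputs: dict[str, Any]) -> str:
        # Canonical JSON (sorted keys) so argument order doesn't change the key.
        payload = json.dumps({"task": task_name, "inputs": inputs}, sort_keys=True)
        return hashlib.sha256(payload.encode()).hexdigest()

    def run(self, fn: Callable[..., Any], **inputs: Any) -> Any:
        k = self.key(fn.__name__, inputs)
        if k not in self.store:
            self.executions += 1  # cache miss: actually execute
            self.store[k] = fn(**inputs)
        return self.store[k]


def fake_train(dataset: str, lr: float, epochs: int) -> dict:
    # Stand-in for train_model; returns placeholder "metrics".
    return {"dataset": dataset, "lr": lr, "epochs": epochs}


cache = InputKeyedCache()
for lr in [1e-3, 2e-3, 1e-3]:  # the third configuration repeats the first
    cache.run(fake_train, dataset="eurosat", lr=lr, epochs=17)
print(cache.executions)  # 2
```

In a hyperparameter sweep this is the property that matters: only configurations with a new key pay for execution.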

### Task 3: Report generation (`report_env`)

This task reads the `metrics.json` produced by training and renders interactive Plotly charts (validation accuracy and train/val loss curves) directly in the Union UI. The `report=True` flag enables a report panel for the task, and the HTML passed to `flyte.report.log` is rendered there. A dashed vertical line marks the Phase 1 → Phase 2 transition, making it easy to see how much the backbone fine-tuning contributes.

```
@report_env.task(report=True)
async def create_report(results_dir: Dir) -> None:
    """
    Download training metrics and render loss/accuracy curves
    in the Union UI report panel.
    """
    import plotly.graph_objects as go
    from pathlib import Path

    local_dir = Path("/tmp/training_report")
    local_dir.mkdir(parents=True, exist_ok=True)
    await results_dir.download(local_path=str(local_dir))

    matches = list(local_dir.glob("**/metrics.json"))
    if not matches:
        raise RuntimeError(f"metrics.json not found under {local_dir}")
    local_path = matches[0].parent

    history = json.loads((local_path / "metrics.json").read_text())

    epochs = [e["epoch"] for e in history]
    val_acc = [e["val_acc"] for e in history]
    val_loss = [e["val_loss"] for e in history]
    train_loss = [e["train_loss"] for e in history]
    # phase_boundary: first epoch where phase 2 begins (frozen → fine-tune transition)
    phase_boundary = next((e["epoch"] for e in history if e["phase"] == 2), None)

    def add_phase_line(fig):
        if phase_boundary is not None:
            fig.add_vline(
                x=phase_boundary,
                line_dash="dash",
                line_color="gray",
                annotation_text="Phase 2 start",
            )

    acc_fig = go.Figure()
    acc_fig.add_trace(go.Scatter(x=epochs, y=val_acc, mode="lines+markers", name="Val Accuracy"))
    acc_fig.update_layout(title="Validation Accuracy", xaxis_title="Epoch", yaxis_title="Accuracy")
    add_phase_line(acc_fig)

    loss_fig = go.Figure()
    loss_fig.add_trace(go.Scatter(x=epochs, y=train_loss, mode="lines+markers", name="Train Loss"))
    loss_fig.add_trace(go.Scatter(x=epochs, y=val_loss, mode="lines+markers", name="Val Loss"))
    loss_fig.update_layout(title="Loss", xaxis_title="Epoch", yaxis_title="Loss")
    add_phase_line(loss_fig)

    combined_html = (
        acc_fig.to_html(include_plotlyjs=True, full_html=False)
        + loss_fig.to_html(include_plotlyjs=False, full_html=False)
    )
    flyte.report.log(combined_html, do_flush=True)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/satellite_image_classification/run.py*

### Task 4: Orchestration (`pipeline_env`)

The pipeline task is a lightweight orchestrator. It has no heavy dependencies of its own, just enough to call the three tasks above. Because each call is awaited before the next begins, the tasks run in sequence, each consuming the previous task's output; `async`/`await` keeps the orchestrator itself non-blocking while Union manages scheduling, retries, and data movement between tasks.

```
@pipeline_env.task
async def satellite_classification_pipeline(config: TrainingConfig) -> None:
    """Orchestrate dataset loading, GPU training, and report generation."""
    dataset_dir = await load_dataset()
    results_dir = await train_model(
        dataset_dir=dataset_dir,
        config=config,
    )
    await create_report(results_dir=results_dir)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/satellite_image_classification/run.py*

## Running the pipeline

Submit the pipeline with a single command from the project directory:

```bash
uv run run.py
```

This calls:
```
    run = flyte.with_runcontext(
        custom_context=wandb_config(
            project=training_config.wandb_project,
            entity=training_config.wandb_entity,
        ),
    ).run(satellite_classification_pipeline, config=training_config)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/satellite_image_classification/run.py*

The W&B project and entity are wired in at submission time. Union handles spinning up the right containers, routing data between tasks, and surfacing results in the UI.

## What you get

After the pipeline completes:

- Union UI: a report panel with interactive accuracy and loss curves, a phase-transition marker, and full task logs for each stage.

  ![Validation Accuracy](https://www.union.ai/docs/v2/union/_static/images/tutorials/satellite_image_classification/validation_accuracy.png)

  ![Loss](https://www.union.ai/docs/v2/union/_static/images/tutorials/satellite_image_classification/loss.png)

- Weights & Biases: a complete experiment run with validation loss and accuracy, training loss, and t-SNE visualizations of the model's learned embeddings. At a configurable epoch interval, a t-SNE plot of the validation-set embeddings is logged, showing how the model's feature representations evolve over training. Classes that start as an overlapping cloud gradually pull apart into tight, well-separated clusters as the backbone learns satellite-specific features.

  ![t-SNE Visualization](https://www.union.ai/docs/v2/union/_static/images/tutorials/satellite_image_classification/tsne.gif)

- Model checkpoints: Lightning's `ModelCheckpoint` saves the top 3 best-performing checkpoints by validation accuracy, named `best-{epoch}-{val_acc}.ckpt`. These are standard PyTorch Lightning checkpoints that can be loaded directly for inference.

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/financial-services ===

# Financial services & fintech

Tutorials for financial research, trading, and other fintech workloads.

### **Financial services & fintech > Financial research agent**

Prep equity briefings for the earnings cycle with grounded You.com Research synthesis and fresh news from the Search API.

### **Financial services & fintech > Fraud detection with Feast**

Train an XGBoost fraud classifier and materialize transaction features in Feast for online scoring.

### **Financial services & fintech > Multi-agent trading simulation**

A multi-agent trading simulation, modeling how agents within a firm might interact, strategize, and make trades collaboratively.

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/financial-services/trading-agents ===

# Multi-agent trading simulation

> [!NOTE]
> Code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/trading_agents); based on work by [TauricResearch](https://github.com/TauricResearch/TradingAgents).

This example walks you through building a multi-agent trading simulation, modeling how agents within a firm might interact, strategize, and make trades collaboratively.

![Trading agents execution visualization](https://raw.githubusercontent.com/unionai/unionai-docs-static/main/images/tutorials/trading-agents/execution.png)
_Trading agents execution visualization_

## TL;DR

- You'll build a trading firm made up of agents that analyze, argue, and act, modeled with Python functions.
- You'll use the Flyte SDK to orchestrate this world — giving you visibility, retries, caching, and durability.
- You'll learn how to plug in tools, structure conversations, and track decisions across agents.
- You'll see how agents debate, use context, generate reports, and retain memory via vector DBs.

## What is an agent, anyway?

Agentic workflows are a rising pattern for complex problem-solving with LLMs. Think of an agent as the combination of:

- An LLM (like GPT-4 or Mistral)
- A loop that keeps it reasoning until a goal is met
- An optional set of tools it can call (APIs, search, calculators, etc.)
- Enough tokens to reason about the problem at hand

That's it.

You define tools, bind them to an agent, and let it run, reasoning step-by-step, optionally using those tools, until it finishes.
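That loop fits in a few lines of plain Python. In this sketch the "LLM" is a scripted stand-in and the tool returns a fake quote; `Step`, `run_agent`, and `scripted_llm` are hypothetical names, and no real model or market data is involved.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class Step:
    """One model decision: call a tool, or return a final answer."""
    tool: Optional[str] = None
    argument: str = ""
    answer: Optional[str] = None


def run_agent(
    llm: Callable[[list[str]], Step],
    tools: dict[str, Callable[[str], str]],
    goal: str,
    max_steps: int = 5,
) -> str:
    transcript = [f"goal: {goal}"]
    for _ in range(max_steps):
        step = llm(transcript)  # the model sees everything so far
        if step.answer is not None:  # goal met: stop looping
            return step.answer
        result = tools[step.tool](step.argument)  # optional tool use
        transcript.append(f"{step.tool}({step.argument}) -> {result}")
    raise RuntimeError("agent did not finish within max_steps")


def scripted_llm(transcript: list[str]) -> Step:
    # Stand-in for a real model: look up a quote once, then answer with it.
    if len(transcript) == 1:
        return Step(tool="quote", argument="NVDA")
    return Step(answer=f"Latest quote: {transcript[-1].split('-> ')[-1]}")


tools = {"quote": lambda ticker: f"{ticker} 100.00 (fake)"}
print(run_agent(scripted_llm, tools, "Fetch a quote for NVDA"))
# Latest quote: NVDA 100.00 (fake)
```

The `max_steps` cap is the safety valve every real agent loop needs, since nothing guarantees a model will ever decide it is done.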

## What's different here?

We're not building yet another agent framework. You're free to use LangChain, custom code, or whatever setup you like.

What we're giving you is the missing piece: a way to run these workflows **reliably, observably, and at scale, with zero rewrites.**

With Flyte, you get:

- Prompt + tool traceability and full state retention
- Built-in retries, caching, and failure recovery
- A native way to plug in your agents; no magic syntax required

## How it works: step-by-step walkthrough

This simulation is powered by a Flyte task that orchestrates multiple intelligent agents working together to analyze a company's stock and make informed trading decisions.

![Trading agents schema](https://raw.githubusercontent.com/unionai/unionai-docs-static/main/images/tutorials/trading-agents/schema.png)
_Trading agents schema_

### Entry point

Everything begins with a top-level Flyte task called `main`, which serves as the entry point to the workflow.

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#     "flyte>=2.0.0b52",
#     "akshare==1.16.98",
#     "backtrader==1.9.78.123",
#     "boto3==1.39.9",
#     "chainlit==2.5.5",
#     "eodhd==1.0.32",
#     "feedparser==6.0.11",
#     "finnhub-python==2.4.23",
#     "langchain-experimental==0.3.4",
#     "langchain-openai==0.3.23",
#     "pandas==2.3.0",
#     "parsel==1.10.0",
#     "praw==7.8.1",
#     "pytz==2025.2",
#     "questionary==2.1.0",
#     "redis==6.2.0",
#     "requests==2.32.4",
#     "stockstats==0.6.5",
#     "tqdm==4.67.1",
#     "tushare==1.4.21",
#     "typing-extensions==4.14.0",
#     "yfinance==0.2.63",
# ]
# main = "main"
# params = ""
# ///
import asyncio
from copy import deepcopy

import agents
import agents.analysts
from agents.managers import create_research_manager, create_risk_manager
from agents.researchers import create_bear_researcher, create_bull_researcher
from agents.risk_debators import (
    create_neutral_debator,
    create_risky_debator,
    create_safe_debator,
)
from agents.trader import create_trader
from agents.utils.utils import AgentState
from flyte_env import DEEP_THINKING_LLM, QUICK_THINKING_LLM, env, flyte
from langchain_openai import ChatOpenAI
from reflection import (
    reflect_bear_researcher,
    reflect_bull_researcher,
    reflect_research_manager,
    reflect_risk_manager,
    reflect_trader,
)

@env.task
async def process_signal(full_signal: str, QUICK_THINKING_LLM: str) -> str:
    """Process a full trading signal to extract the core decision."""

    messages = [
        {
            "role": "system",
            "content": """You are an efficient assistant designed to analyze paragraphs or
financial reports provided by a group of analysts.
Your task is to extract the investment decision: SELL, BUY, or HOLD.
Provide only the extracted decision (SELL, BUY, or HOLD) as your output,
without adding any additional text or information.""",
        },
        {"role": "human", "content": full_signal},
    ]

    return ChatOpenAI(model=QUICK_THINKING_LLM).invoke(messages).content

async def run_analyst(analyst_name, state, online_tools):
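    # Look up the analyst factory by name, e.g. create_market_analyst.
    # The caller passes in a deep copy of the state, so analysts run in isolation.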
    run_fn = getattr(agents.analysts, f"create_{analyst_name}_analyst")

    # Run the analyst's chain
    result_state = await run_fn(QUICK_THINKING_LLM, state, online_tools)

    # Determine the report key
    report_key = (
        "sentiment_report"
        if analyst_name == "social_media"
        else f"{analyst_name}_report"
    )
    report_value = getattr(result_state, report_key)

    return result_state.messages[1:], report_key, report_value

@env.task
async def main(
    selected_analysts: list[str] = [
        "market",
        "fundamentals",
        "news",
        "social_media",
    ],
    max_debate_rounds: int = 1,
    max_risk_discuss_rounds: int = 1,
    online_tools: bool = True,
    company_name: str = "NVDA",
    trade_date: str = "2024-05-12",
) -> tuple[str, AgentState]:
    if not selected_analysts:
        raise ValueError(
            "No analysts selected. Please select at least one analyst from market, fundamentals, news, or social_media."
        )

    state = AgentState(
        messages=[{"role": "human", "content": company_name}],
        company_of_interest=company_name,
        trade_date=str(trade_date),
    )

    # Run all analysts concurrently
    results = await asyncio.gather(
        *[
            run_analyst(analyst, deepcopy(state), online_tools)
            for analyst in selected_analysts
        ]
    )

    # Flatten and append all resulting messages into the shared state
    for messages, report_attr, report in results:
        state.messages.extend(messages)
        setattr(state, report_attr, report)

    # Bull/Bear debate loop
    state = await create_bull_researcher(QUICK_THINKING_LLM, state)  # Start with bull
    while state.investment_debate_state.count < 2 * max_debate_rounds:
        current = state.investment_debate_state.current_response
        if current.startswith("Bull"):
            state = await create_bear_researcher(QUICK_THINKING_LLM, state)
        else:
            state = await create_bull_researcher(QUICK_THINKING_LLM, state)

    state = await create_research_manager(DEEP_THINKING_LLM, state)
    state = await create_trader(QUICK_THINKING_LLM, state)

    # Risk debate loop
    state = await create_risky_debator(QUICK_THINKING_LLM, state)  # Start with risky
    while state.risk_debate_state.count < 3 * max_risk_discuss_rounds:
        speaker = state.risk_debate_state.latest_speaker
        if speaker == "Risky":
            state = await create_safe_debator(QUICK_THINKING_LLM, state)
        elif speaker == "Safe":
            state = await create_neutral_debator(QUICK_THINKING_LLM, state)
        else:
            state = await create_risky_debator(QUICK_THINKING_LLM, state)

    state = await create_risk_manager(DEEP_THINKING_LLM, state)
    decision = await process_signal(state.final_trade_decision, QUICK_THINKING_LLM)

    return decision, state

@env.task
async def reflect_and_store(state: AgentState, returns: str) -> str:
    await asyncio.gather(
        reflect_bear_researcher(state, returns),
        reflect_bull_researcher(state, returns),
        reflect_trader(state, returns),
        reflect_risk_manager(state, returns),
        reflect_research_manager(state, returns),
    )

    return "Reflection completed."

# Run the reflection task after the main function
@env.task(cache="disable")
async def reflect_on_decisions(
    returns: str,
    selected_analysts: list[str] = [
        "market",
        "fundamentals",
        "news",
        "social_media",
    ],
    max_debate_rounds: int = 1,
    max_risk_discuss_rounds: int = 1,
    online_tools: bool = True,
    company_name: str = "NVDA",
    trade_date: str = "2024-05-12",
) -> str:
    _, state = await main(
        selected_analysts,
        max_debate_rounds,
        max_risk_discuss_rounds,
        online_tools,
        company_name,
        trade_date,
    )

    return await reflect_and_store(state, returns)

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
    run.wait()

    # run = flyte.run(reflect_on_decisions, "+3.2% gain over 5 days")
    # print(run.url)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/main.py*

This task accepts several inputs:

- the list of analysts to run,
- the number of debate and risk discussion rounds,
- a flag to enable online tools,
- the company you're evaluating,
- and the target trading date.

The most interesting parameter here is the list of analysts to run. It determines which analyst agents will be invoked and shapes the overall structure of the simulation. Based on this input, the task dynamically launches agent tasks, running them in parallel.
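The fan-out pattern in `main` (deep-copy the state per analyst, run them concurrently, merge results back) can be sketched without LLMs. `State` and `fake_analyst` are hypothetical stand-ins; the tutorial's `AgentState` stores reports as attributes via `setattr`, whereas this sketch uses a dict.

```python
import asyncio
from copy import deepcopy
from dataclasses import dataclass, field


@dataclass
class State:
    company: str
    messages: list[str] = field(default_factory=list)
    reports: dict[str, str] = field(default_factory=dict)


async def fake_analyst(name: str, state: State) -> tuple[list[str], str, str]:
    # Each analyst mutates only its own copy of the state.
    await asyncio.sleep(0)
    state.messages.append(f"{name} looked at {state.company}")
    return state.messages[-1:], f"{name}_report", f"{name} view on {state.company}"


async def fan_out(selected: list[str], company: str) -> State:
    state = State(company=company)
    results = await asyncio.gather(
        *[fake_analyst(name, deepcopy(state)) for name in selected]
    )
    # Merge in selection order; gather preserves the order of its arguments.
    for messages, key, report in results:
        state.messages.extend(messages)
        state.reports[key] = report
    return state


final = asyncio.run(fan_out(["market", "news"], "NVDA"))
print(final.reports)
```

Without the `deepcopy`, concurrent analysts would append to the same message list, and the merged transcript would depend on scheduling order.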

The `main` task is written as a regular asynchronous Python function wrapped with Flyte's task decorator. No domain-specific language or orchestration glue is needed, just idiomatic Python; `async` is what lets independent agents run concurrently. The task environment is configured once and shared across all tasks for consistency.

```
import flyte

QUICK_THINKING_LLM = "gpt-4o-mini"
DEEP_THINKING_LLM = "o4-mini"

env = flyte.TaskEnvironment(
    name="trading-agents",
    secrets=[
        flyte.Secret(key="finnhub_api_key", as_env_var="FINNHUB_API_KEY"),
        flyte.Secret(key="openai_api_key", as_env_var="OPENAI_API_KEY"),
    ],
    image=flyte.Image.from_uv_script("main.py", name="trading-agents", pre=True),
    resources=flyte.Resources(cpu="1"),
    cache="auto",
)

# {{/docs-fragment env}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/flyte_env.py*
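The `cache="auto"` setting on the shared environment is what allows repeated task calls with identical inputs to skip recomputation. Flyte handles this itself. The sketch below only mimics the idea with a hypothetical `cached` decorator that keys an async function's results on its arguments, plus a hypothetical `fetch_prices` standing in for a data-download task.

```python
import asyncio
import functools

CALLS = {"count": 0}  # how many times the underlying function really ran


def cached(fn):
    """Memoize an async function on its positional and keyword arguments."""
    results: dict = {}

    @functools.wraps(fn)
    async def wrapper(*args, **kwargs):
        key = (args, tuple(sorted(kwargs.items())))
        if key not in results:
            results[key] = await fn(*args, **kwargs)
        return results[key]

    return wrapper


@cached
async def fetch_prices(symbol: str, start_date: str, end_date: str) -> str:
    # Hypothetical stand-in for an expensive download task.
    CALLS["count"] += 1
    await asyncio.sleep(0)
    return f"{symbol}:{start_date}..{end_date}"


async def demo() -> int:
    await fetch_prices("NVDA", "2009-05-12", "2024-05-12")
    await fetch_prices("NVDA", "2009-05-12", "2024-05-12")  # same inputs: cache hit
    await fetch_prices("AAPL", "2009-05-12", "2024-05-12")  # new inputs: recomputed
    return CALLS["count"]
```

Unlike a platform cache, this in-memory version lives only as long as the process and does not deduplicate concurrent first calls that share a key.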

### Analyst agents

Each analyst agent comes equipped with a set of tools and a carefully designed prompt tailored to its specific domain. These tools are modular Flyte tasks (for example, downloading financial reports or computing technical indicators) and benefit from Flyte's built-in caching, enabled by `cache="auto"` on the shared environment, to avoid redundant computation.

```
from datetime import datetime

import pandas as pd
import tools.interface as interface
import yfinance as yf
from flyte_env import env

from flyte.io import File

@env.task
async def get_reddit_news(
    curr_date: str,  # Date you want to get news for in yyyy-mm-dd format
) -> str:
    """
    Retrieve global news from Reddit within a specified time frame.
    Args:
        curr_date (str): Date you want to get news for in yyyy-mm-dd format
    Returns:
        str: A formatted dataframe containing the latest global news
        from Reddit in the specified time frame.
    """

    global_news_result = interface.get_reddit_global_news(curr_date, 7, 5)

    return global_news_result

@env.task
async def get_finnhub_news(
    ticker: str,  # Ticker of a company, e.g. AAPL, TSM
    start_date: str,  # Start date in yyyy-mm-dd format
    end_date: str,  # End date in yyyy-mm-dd format
) -> str:
    """
    Retrieve the latest news about a given stock from Finnhub within a date range
    Args:
        ticker (str): Ticker of a company. e.g. AAPL, TSM
        start_date (str): Start date in yyyy-mm-dd format
        end_date (str): End date in yyyy-mm-dd format
    Returns:
        str: A formatted dataframe containing news about the company
        within the date range from start_date to end_date
    """

    end_date_str = end_date

    end_date = datetime.strptime(end_date, "%Y-%m-%d")
    start_date = datetime.strptime(start_date, "%Y-%m-%d")
    look_back_days = (end_date - start_date).days

    finnhub_news_result = interface.get_finnhub_news(
        ticker, end_date_str, look_back_days
    )

    return finnhub_news_result

@env.task
async def get_reddit_stock_info(
    ticker: str,  # Ticker of a company. e.g. AAPL, TSM
    curr_date: str,  # Current date you want to get news for
) -> str:
    """
    Retrieve the latest news about a given stock from Reddit, given the current date.
    Args:
        ticker (str): Ticker of a company. e.g. AAPL, TSM
        curr_date (str): current date in yyyy-mm-dd format to get news for
    Returns:
        str: A formatted dataframe containing the latest news about the company on the given date
    """

    stock_news_results = interface.get_reddit_company_news(ticker, curr_date, 7, 5)

    return stock_news_results

@env.task
async def get_YFin_data(
    symbol: str,  # ticker symbol of the company
    start_date: str,  # Start date in yyyy-mm-dd format
    end_date: str,  # End date in yyyy-mm-dd format
) -> str:
    """
    Retrieve the stock price data for a given ticker symbol from Yahoo Finance.
    Args:
        symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
        start_date (str): Start date in yyyy-mm-dd format
        end_date (str): End date in yyyy-mm-dd format
    Returns:
        str: A formatted dataframe containing the stock price data
        for the specified ticker symbol in the specified date range.
    """

    result_data = interface.get_YFin_data(symbol, start_date, end_date)

    return result_data

@env.task
async def get_YFin_data_online(
    symbol: str,  # ticker symbol of the company
    start_date: str,  # Start date in yyyy-mm-dd format
    end_date: str,  # End date in yyyy-mm-dd format
) -> str:
    """
    Retrieve the stock price data for a given ticker symbol from Yahoo Finance.
    Args:
        symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
        start_date (str): Start date in yyyy-mm-dd format
        end_date (str): End date in yyyy-mm-dd format
    Returns:
        str: A formatted dataframe containing the stock price data
        for the specified ticker symbol in the specified date range.
    """

    result_data = interface.get_YFin_data_online(symbol, start_date, end_date)

    return result_data

@env.task
async def cache_market_data(symbol: str, start_date: str, end_date: str) -> File:
    data_file = f"{symbol}-YFin-data-{start_date}-{end_date}.csv"

    data = yf.download(
        symbol,
        start=start_date,
        end=end_date,
        multi_level_index=False,
        progress=False,
        auto_adjust=True,
    )
    data = data.reset_index()
    data.to_csv(data_file, index=False)

    return await File.from_local(data_file)

@env.task
async def get_stockstats_indicators_report(
    symbol: str,  # ticker symbol of the company
    indicator: str,  # technical indicator to get the analysis and report of
    curr_date: str,  # The current trading date you are trading on, YYYY-mm-dd
    look_back_days: int = 30,  # how many days to look back
) -> str:
    """
    Retrieve stock stats indicators for a given ticker symbol and indicator.
    Args:
        symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
        indicator (str): Technical indicator to get the analysis and report of
        curr_date (str): The current trading date you are trading on, YYYY-mm-dd
        look_back_days (int): How many days to look back, default is 30
    Returns:
        str: A formatted dataframe containing the stock stats indicators
        for the specified ticker symbol and indicator.
    """

    today_date = pd.Timestamp.today()

    end_date = today_date
    start_date = today_date - pd.DateOffset(years=15)
    start_date = start_date.strftime("%Y-%m-%d")
    end_date = end_date.strftime("%Y-%m-%d")

    data_file = await cache_market_data(symbol, start_date, end_date)
    local_data_file = await data_file.download()

    result_stockstats = interface.get_stock_stats_indicators_window(
        symbol, indicator, curr_date, look_back_days, False, local_data_file
    )

    return result_stockstats

# {{docs-fragment get_stockstats_indicators_report_online}}
@env.task
async def get_stockstats_indicators_report_online(
    symbol: str,  # ticker symbol of the company
    indicator: str,  # technical indicator to get the analysis and report of
    curr_date: str,  # The current trading date you are trading on, YYYY-mm-dd
    look_back_days: int = 30,  # how many days to look back
) -> str:
    """
    Retrieve stock stats indicators for a given ticker symbol and indicator.
    Args:
        symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
        indicator (str): Technical indicator to get the analysis and report of
        curr_date (str): The current trading date you are trading on, YYYY-mm-dd
        look_back_days (int): How many days to look back, default is 30
    Returns:
        str: A formatted dataframe containing the stock stats indicators
        for the specified ticker symbol and indicator.
    """

    today_date = pd.Timestamp.today()

    end_date = today_date
    start_date = today_date - pd.DateOffset(years=15)
    start_date = start_date.strftime("%Y-%m-%d")
    end_date = end_date.strftime("%Y-%m-%d")

    data_file = await cache_market_data(symbol, start_date, end_date)
    local_data_file = await data_file.download()

    result_stockstats = interface.get_stock_stats_indicators_window(
        symbol, indicator, curr_date, look_back_days, True, local_data_file
    )

    return result_stockstats

# {{/docs-fragment get_stockstats_indicators_report_online}}

@env.task
async def get_finnhub_company_insider_sentiment(
    ticker: str,  # ticker symbol for the company
    curr_date: str,  # current date you are trading at, yyyy-mm-dd
) -> str:
    """
    Retrieve insider sentiment information about a company (retrieved
    from public SEC information) for the past 30 days
    Args:
        ticker (str): ticker symbol of the company
        curr_date (str): current date you are trading at, yyyy-mm-dd
    Returns:
        str: a report of the sentiment in the past 30 days starting at curr_date
    """

    data_sentiment = interface.get_finnhub_company_insider_sentiment(
        ticker, curr_date, 30
    )

    return data_sentiment

@env.task
async def get_finnhub_company_insider_transactions(
    ticker: str,  # ticker symbol
    curr_date: str,  # current date you are trading at, yyyy-mm-dd
) -> str:
    """
    Retrieve insider transaction information about a company
    (retrieved from public SEC information) for the past 30 days
    Args:
        ticker (str): ticker symbol of the company
        curr_date (str): current date you are trading at, yyyy-mm-dd
    Returns:
        str: a report of the company's insider transactions/trading information in the past 30 days
    """

    data_trans = interface.get_finnhub_company_insider_transactions(
        ticker, curr_date, 30
    )

    return data_trans

@env.task
async def get_simfin_balance_sheet(
    ticker: str,  # ticker symbol
    freq: str,  # reporting frequency of the company's financial history: annual/quarterly
    curr_date: str,  # current date you are trading at, yyyy-mm-dd
) -> str:
    """
    Retrieve the most recent balance sheet of a company
    Args:
        ticker (str): ticker symbol of the company
        freq (str): reporting frequency of the company's financial history: annual / quarterly
        curr_date (str): current date you are trading at, yyyy-mm-dd
    Returns:
        str: a report of the company's most recent balance sheet
    """

    data_balance_sheet = interface.get_simfin_balance_sheet(ticker, freq, curr_date)

    return data_balance_sheet

@env.task
async def get_simfin_cashflow(
    ticker: str,  # ticker symbol
    freq: str,  # reporting frequency of the company's financial history: annual/quarterly
    curr_date: str,  # current date you are trading at, yyyy-mm-dd
) -> str:
    """
    Retrieve the most recent cash flow statement of a company
    Args:
        ticker (str): ticker symbol of the company
        freq (str): reporting frequency of the company's financial history: annual / quarterly
        curr_date (str): current date you are trading at, yyyy-mm-dd
    Returns:
        str: a report of the company's most recent cash flow statement
    """

    data_cashflow = interface.get_simfin_cashflow(ticker, freq, curr_date)

    return data_cashflow

@env.task
async def get_simfin_income_stmt(
    ticker: str,  # ticker symbol
    freq: str,  # reporting frequency of the company's financial history: annual/quarterly
    curr_date: str,  # current date you are trading at, yyyy-mm-dd
) -> str:
    """
    Retrieve the most recent income statement of a company
    Args:
        ticker (str): ticker symbol of the company
        freq (str): reporting frequency of the company's financial history: annual / quarterly
        curr_date (str): current date you are trading at, yyyy-mm-dd
    Returns:
        str: a report of the company's most recent income statement
    """

    data_income_stmt = interface.get_simfin_income_statements(ticker, freq, curr_date)

    return data_income_stmt

@env.task
async def get_google_news(
    query: str,  # Query to search with
    curr_date: str,  # Current date in yyyy-mm-dd format
) -> str:
    """
    Retrieve the latest news from Google News for a query, looking back 7 days from curr_date.
    Args:
        query (str): Query to search with
        curr_date (str): Current date in yyyy-mm-dd format
    Returns:
        str: A formatted string containing the latest news from Google News
        for the query over the 7-day look-back window.
    """

    google_news_results = interface.get_google_news(query, curr_date, 7)

    return google_news_results

@env.task
async def get_stock_news_openai(
    ticker: str,  # the company's ticker
    curr_date: str,  # Current date in yyyy-mm-dd format
) -> str:
    """
    Retrieve the latest news about a given stock by using OpenAI's news API.
    Args:
        ticker (str): Ticker of a company. e.g. AAPL, TSM
        curr_date (str): Current date in yyyy-mm-dd format
    Returns:
        str: A formatted string containing the latest news about the company on the given date.
    """

    openai_news_results = interface.get_stock_news_openai(ticker, curr_date)

    return openai_news_results

@env.task
async def get_global_news_openai(
    curr_date: str,  # Current date in yyyy-mm-dd format
) -> str:
    """
    Retrieve the latest macroeconomics news on a given date using OpenAI's macroeconomics news API.
    Args:
        curr_date (str): Current date in yyyy-mm-dd format
    Returns:
        str: A formatted string containing the latest macroeconomic news on the given date.
    """

    openai_news_results = interface.get_global_news_openai(curr_date)

    return openai_news_results

@env.task
async def get_fundamentals_openai(
    ticker: str,  # the company's ticker
    curr_date: str,  # Current date in yyyy-mm-dd format
) -> str:
    """
    Retrieve the latest fundamental information about a given stock
    on a given date by using OpenAI's news API.
    Args:
        ticker (str): Ticker of a company. e.g. AAPL, TSM
        curr_date (str): Current date in yyyy-mm-dd format

    Returns:
        str: A formatted string containing the latest fundamental information
        about the company on the given date.
    """

    openai_fundamentals_results = interface.get_fundamentals_openai(ticker, curr_date)

    return openai_fundamentals_results
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/tools/toolkit.py*
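Two date computations in the toolkit are worth spelling out. `get_finnhub_news` converts a date range into a look-back window in days, and both `get_stockstats_indicators_report*` tasks always request a 15-year window ending today. As a result, the inputs to `cache_market_data` (and therefore its cache entry) change once per calendar day rather than per trading date. The following is a stdlib-only sketch of both computations with hypothetical helper names. The originals use `pd.DateOffset`, which clamps February 29 to February 28; this sketch does that clamping by hand.

```python
from datetime import date, datetime


def look_back_days(start_date: str, end_date: str) -> int:
    # Same arithmetic as get_finnhub_news: whole days between two yyyy-mm-dd dates.
    start = datetime.strptime(start_date, "%Y-%m-%d")
    end = datetime.strptime(end_date, "%Y-%m-%d")
    return (end - start).days


def fifteen_year_window(today: date) -> tuple[str, str]:
    # Mirrors the window passed to cache_market_data: 15 years back, ending today.
    try:
        start = today.replace(year=today.year - 15)
    except ValueError:
        # Feb 29 has no counterpart 15 years earlier; fall back to Feb 28.
        start = today.replace(year=today.year - 15, day=28)
    return start.strftime("%Y-%m-%d"), today.strftime("%Y-%m-%d")
```

Because `curr_date` is not one of the `cache_market_data` inputs, every indicator request for the same symbol on the same day can reuse a single cached download. The next day's window produces new inputs and triggers a fresh download.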

When initialized, an analyst enters a structured reasoning loop (via LangChain), where it can call tools, observe outputs, and refine its internal state before generating a final report. These reports are later consumed by downstream agents.
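Each analyst writes its final report into a state field named after the analyst type. The helper in the file below calls `setattr(state, f"{type}_report", ...)`, and the social-media analyst passes `"sentiment"`, so its output lands in `sentiment_report`. The sketch below illustrates this convention with a hypothetical `SketchState` dataclass (the real `AgentState` definition is not shown here) and a hypothetical `collect_reports` step standing in for a downstream consumer. Unlike the original helper, `store_report` rejects unknown field names instead of setting them silently.

```python
from dataclasses import dataclass, field


@dataclass
class SketchState:
    # Hypothetical subset of AgentState; field names follow the f"{type}_report"
    # convention used by run_chain_with_tools.
    company_of_interest: str
    trade_date: str
    messages: list = field(default_factory=list)
    market_report: str = ""
    sentiment_report: str = ""
    news_report: str = ""
    fundamentals_report: str = ""


def store_report(state: SketchState, report_type: str, content: str) -> None:
    # Same dynamic-attribute pattern as the helper, but fail loudly on a typo.
    attr = f"{report_type}_report"
    if not hasattr(state, attr):
        raise AttributeError(f"state has no field {attr!r}")
    setattr(state, attr, content)


def collect_reports(state: SketchState) -> str:
    # Hypothetical downstream step: concatenate the non-empty analyst reports.
    sections = [
        ("Market", state.market_report),
        ("Sentiment", state.sentiment_report),
        ("News", state.news_report),
        ("Fundamentals", state.fundamentals_report),
    ]
    return "\n\n".join(f"## {title}\n{body}" for title, body in sections if body)
```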

The file below defines all four analysts. The news analyst, for example, interprets global events and macroeconomic signals. Each analyst specifies the tools it can access, and the LLM selects which ones to call based on context.

```
import asyncio

from agents.utils.utils import AgentState
from flyte_env import env
from langchain_core.messages import ToolMessage, convert_to_openai_messages
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import ChatOpenAI
from tools import toolkit

import flyte

MAX_ITERATIONS = 5

# {{docs-fragment agent_helper}}
async def run_chain_with_tools(
    type: str, state: AgentState, llm: str, system_message: str, tool_names: list[str]
) -> AgentState:
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "You are a helpful AI assistant, collaborating with other assistants."
                " Use the provided tools to progress towards answering the question."
                " If you are unable to fully answer, that's OK; another assistant with different tools"
                " will help where you left off. Execute what you can to make progress."
                " If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
                " prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
                " You have access to the following tools: {tool_names}.\n{system_message}"
                " For your reference, the current date is {current_date}. The company we want to look at is {ticker}.",
            ),
            MessagesPlaceholder(variable_name="messages"),
        ]
    )

    prompt = prompt.partial(system_message=system_message)
    prompt = prompt.partial(tool_names=", ".join(tool_names))
    prompt = prompt.partial(current_date=state.trade_date)
    prompt = prompt.partial(ticker=state.company_of_interest)

    chain = prompt | ChatOpenAI(model=llm).bind_tools(
        [getattr(toolkit, tool_name).func for tool_name in tool_names]
    )

    iteration = 0
    while iteration < MAX_ITERATIONS:
        result = await chain.ainvoke(state.messages)
        state.messages.append(convert_to_openai_messages(result))

        if not result.tool_calls:
            # Final response — no tools required
            setattr(state, f"{type}_report", result.content or "")
            break

        # Run all tool calls in parallel
        async def run_single_tool(tool_call):
            tool_name = tool_call["name"]
            tool_args = tool_call["args"]
            tool = getattr(toolkit, tool_name, None)

            if not tool:
                return None

            content = await tool(**tool_args)
            return ToolMessage(
                tool_call_id=tool_call["id"], name=tool_name, content=content
            )

        with flyte.group(f"tool_calls_iteration_{iteration}"):
            tool_messages = await asyncio.gather(
                *[run_single_tool(tc) for tc in result.tool_calls]
            )

        # Add valid tool results to state
        tool_messages = [msg for msg in tool_messages if msg]
        state.messages.extend(convert_to_openai_messages(tool_messages))

        iteration += 1
    else:
        # Reached iteration cap — optionally raise or log
        print(f"Max iterations ({MAX_ITERATIONS}) reached for {type}")

    return state

# {{/docs-fragment agent_helper}}

@env.task
async def create_fundamentals_analyst(
    llm: str, state: AgentState, online_tools: bool
) -> AgentState:
    if online_tools:
        tools = [toolkit.get_fundamentals_openai]
    else:
        tools = [
            toolkit.get_finnhub_company_insider_sentiment,
            toolkit.get_finnhub_company_insider_transactions,
            toolkit.get_simfin_balance_sheet,
            toolkit.get_simfin_cashflow,
            toolkit.get_simfin_income_stmt,
        ]

    system_message = (
        "You are a researcher tasked with analyzing fundamental information over the past week about a company. "
        "Please write a comprehensive report of the company's fundamental information such as financial documents, "
        "company profile, basic company financials, company financial history, insider sentiment, and insider "
        "transactions to gain a full view of the company's "
        "fundamental information to inform traders. Make sure to include as much detail as possible. "
        "Do not simply state the trends are mixed, "
        "provide detailed and finegrained analysis and insights that may help traders make decisions. "
        "Make sure to append a Markdown table at the end of the report to organize key points in the report, "
        "organized and easy to read."
    )

    tool_names = [tool.func.__name__ for tool in tools]

    return await run_chain_with_tools(
        "fundamentals", state, llm, system_message, tool_names
    )

@env.task
async def create_market_analyst(
    llm: str, state: AgentState, online_tools: bool
) -> AgentState:
    if online_tools:
        tools = [
            toolkit.get_YFin_data_online,
            toolkit.get_stockstats_indicators_report_online,
        ]
    else:
        tools = [
            toolkit.get_YFin_data,
            toolkit.get_stockstats_indicators_report,
        ]

    system_message = (
        """You are a trading assistant tasked with analyzing financial markets.
Your role is to select the **most relevant indicators** for a given market condition
or trading strategy from the following list.
The goal is to choose up to **8 indicators** that provide complementary insights without redundancy.
Categories and each category's indicators are:

Moving Averages:
- close_50_sma: 50 SMA: A medium-term trend indicator.
Usage: Identify trend direction and serve as dynamic support/resistance.
Tips: It lags price; combine with faster indicators for timely signals.
- close_200_sma: 200 SMA: A long-term trend benchmark.
Usage: Confirm overall market trend and identify golden/death cross setups.
Tips: It reacts slowly; best for strategic trend confirmation rather than frequent trading entries.
- close_10_ema: 10 EMA: A responsive short-term average.
Usage: Capture quick shifts in momentum and potential entry points.
Tips: Prone to noise in choppy markets; use alongside longer averages for filtering false signals.

MACD Related:
- macd: MACD: Computes momentum via differences of EMAs.
Usage: Look for crossovers and divergence as signals of trend changes.
Tips: Confirm with other indicators in low-volatility or sideways markets.
- macds: MACD Signal: An EMA smoothing of the MACD line.
Usage: Use crossovers with the MACD line to trigger trades.
Tips: Should be part of a broader strategy to avoid false positives.
- macdh: MACD Histogram: Shows the gap between the MACD line and its signal.
Usage: Visualize momentum strength and spot divergence early.
Tips: Can be volatile; complement with additional filters in fast-moving markets.

Momentum Indicators:
- rsi: RSI: Measures momentum to flag overbought/oversold conditions.
Usage: Apply 70/30 thresholds and watch for divergence to signal reversals.
Tips: In strong trends, RSI may remain extreme; always cross-check with trend analysis.

Volatility Indicators:
- boll: Bollinger Middle: A 20 SMA serving as the basis for Bollinger Bands.
Usage: Acts as a dynamic benchmark for price movement.
Tips: Combine with the upper and lower bands to effectively spot breakouts or reversals.
- boll_ub: Bollinger Upper Band: Typically 2 standard deviations above the middle line.
Usage: Signals potential overbought conditions and breakout zones.
Tips: Confirm signals with other tools; prices may ride the band in strong trends.
- boll_lb: Bollinger Lower Band: Typically 2 standard deviations below the middle line.
Usage: Indicates potential oversold conditions.
Tips: Use additional analysis to avoid false reversal signals.
- atr: ATR: Averages true range to measure volatility.
Usage: Set stop-loss levels and adjust position sizes based on current market volatility.
Tips: It's a reactive measure, so use it as part of a broader risk management strategy.

Volume-Based Indicators:
- vwma: VWMA: A moving average weighted by volume.
Usage: Confirm trends by integrating price action with volume data.
Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses.

- Select indicators that provide diverse and complementary information.
Avoid redundancy (e.g., do not select both rsi and stochrsi).
Also briefly explain why they are suitable for the given market context.
When you tool call, please use the exact name of the indicators provided above as they are defined parameters,
otherwise your call will fail.
Please make sure to call get_YFin_data first to retrieve the CSV that is needed to generate indicators.
Write a very detailed and nuanced report of the trends you observe.
Do not simply state the trends are mixed, provide detailed and finegrained analysis
and insights that may help traders make decisions."""
        """ Make sure to append a Markdown table at the end of the report to
        organize key points in the report, organized and easy to read."""
    )

    tool_names = [tool.func.__name__ for tool in tools]
    return await run_chain_with_tools("market", state, llm, system_message, tool_names)

# {{docs-fragment news_analyst}}
@env.task
async def create_news_analyst(
    llm: str, state: AgentState, online_tools: bool
) -> AgentState:
    if online_tools:
        tools = [
            toolkit.get_global_news_openai,
            toolkit.get_google_news,
        ]
    else:
        tools = [
            toolkit.get_finnhub_news,
            toolkit.get_reddit_news,
            toolkit.get_google_news,
        ]

    system_message = (
        "You are a news researcher tasked with analyzing recent news and trends over the past week. "
        "Please write a comprehensive report of the current state of the world that is relevant for "
        "trading and macroeconomics. "
        "Look at news from EODHD, and finnhub to be comprehensive. Do not simply state the trends are mixed, "
        "provide detailed and finegrained analysis and insights that may help traders make decisions."
        """ Make sure to append a Markdown table at the end of the report to organize key points in the report,
        organized and easy to read."""
    )

    tool_names = [tool.func.__name__ for tool in tools]

    return await run_chain_with_tools("news", state, llm, system_message, tool_names)

# {{/docs-fragment news_analyst}}

@env.task
async def create_social_media_analyst(
    llm: str, state: AgentState, online_tools: bool
) -> AgentState:
    if online_tools:
        tools = [toolkit.get_stock_news_openai]
    else:
        tools = [toolkit.get_reddit_stock_info]

    system_message = (
        "You are a social media and company specific news researcher/analyst tasked with analyzing social media posts, "
        "recent company news, and public sentiment for a specific company over the past week. "
        "You will be given a company's name; your objective is to write a comprehensive long report "
        "detailing your analysis, insights, and implications for traders and investors on this company's current state "
        "after looking at social media and what people are saying about that company, "
        "analyzing sentiment data of what people feel each day about the company, and looking at recent company news. "
        "Try to look at all sources possible from social media to sentiment to news. Do not simply state the trends "
        "are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
        """ Make sure to append a Markdown table at the end of the report to organize key points in the report,
        organized and easy to read."""
    )

    tool_names = [tool.func.__name__ for tool in tools]

    return await run_chain_with_tools(
        "sentiment", state, llm, system_message, tool_names
    )
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/agents/analysts.py*
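The `run_single_tool` closure above resolves each requested tool by name, runs all calls of one iteration concurrently with `asyncio.gather`, and drops calls whose name does not resolve. Here is a stdlib-only sketch of that dispatch step. It uses hypothetical stub tools and plain dicts in place of LangChain `ToolMessage` objects. One deliberate difference: lookup is limited to a dict of the tools granted to the analyst, whereas the original uses `getattr` on the whole `toolkit` module.

```python
import asyncio
from typing import Optional


async def get_google_news(query: str, curr_date: str) -> str:
    # Hypothetical stand-in for the toolkit task of the same name.
    await asyncio.sleep(0)
    return f"news about {query} as of {curr_date}"


async def get_global_news_openai(curr_date: str) -> str:
    # Hypothetical stand-in for the toolkit task of the same name.
    await asyncio.sleep(0)
    return f"macro news as of {curr_date}"


# Only the tools granted to this analyst are resolvable by name.
TOOLS = {fn.__name__: fn for fn in (get_google_news, get_global_news_openai)}


async def run_tool_calls(tool_calls: list) -> list:
    async def run_one(call: dict) -> Optional[dict]:
        tool = TOOLS.get(call["name"])
        if tool is None:
            return None  # unknown tool: skipped, as in run_single_tool
        content = await tool(**call["args"])
        return {"role": "tool", "tool_call_id": call["id"], "content": content}

    # gather preserves the order of the requested calls.
    results = await asyncio.gather(*(run_one(c) for c in tool_calls))
    return [r for r in results if r is not None]
```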

Each analyst agent uses the `run_chain_with_tools` helper to bind its tools, iterate through reasoning steps (up to `MAX_ITERATIONS`, set to 5), and produce a report. The iteration cap prevents runaway tool-calling loops. As agents reason, their message history is preserved in `AgentState` and passed along to the next agent in the chain.
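The helper's control flow, and Python's `while ... else` in particular, is easy to misread. The sketch below is a synchronous version with a scripted fake model; `FakeReply` and `run_bounded_loop` are hypothetical names. The `else` branch runs only when the loop ends without `break`, which is exactly the case where the iteration cap is reached. In the original helper, that branch just prints a message, so the `{type}_report` field keeps whatever value it had. This sketch signals the condition explicitly with a `hit_cap` flag.

```python
from dataclasses import dataclass, field

MAX_ITERATIONS = 5


@dataclass
class FakeReply:
    # Hypothetical model reply: either tool calls or final content.
    content: str = ""
    tool_calls: list = field(default_factory=list)


def run_bounded_loop(replies, max_iterations: int = MAX_ITERATIONS):
    """Return (report, history, hit_cap) for a scripted sequence of replies."""
    history = []
    replies = iter(replies)
    iteration = 0
    while iteration < max_iterations:
        reply = next(replies)
        history.append(f"assistant:{reply.content or ','.join(reply.tool_calls)}")
        if not reply.tool_calls:
            report = reply.content  # final answer: stop looping
            break
        # Stand-in for executing the tools and recording their outputs.
        history.extend(f"tool:{name}" for name in reply.tool_calls)
        iteration += 1
    else:
        # Runs only when the loop exits without break, i.e. the cap was hit.
        return "", history, True
    return report, history, False
```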

```
import asyncio

from agents.utils.utils import AgentState
from flyte_env import env
from langchain_core.messages import ToolMessage, convert_to_openai_messages
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import ChatOpenAI
from tools import toolkit

import flyte

MAX_ITERATIONS = 5

# {{docs-fragment agent_helper}}
async def run_chain_with_tools(
    type: str, state: AgentState, llm: str, system_message: str, tool_names: list[str]
) -> AgentState:
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "You are a helpful AI assistant, collaborating with other assistants."
                " Use the provided tools to progress towards answering the question."
                " If you are unable to fully answer, that's OK; another assistant with different tools"
                " will help where you left off. Execute what you can to make progress."
                " If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
                " prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
                " You have access to the following tools: {tool_names}.\n{system_message}"
                " For your reference, the current date is {current_date}. The company we want to look at is {ticker}.",
            ),
            MessagesPlaceholder(variable_name="messages"),
        ]
    )

    prompt = prompt.partial(system_message=system_message)
    prompt = prompt.partial(tool_names=", ".join(tool_names))
    prompt = prompt.partial(current_date=state.trade_date)
    prompt = prompt.partial(ticker=state.company_of_interest)

    chain = prompt | ChatOpenAI(model=llm).bind_tools(
        [getattr(toolkit, tool_name).func for tool_name in tool_names]
    )

    iteration = 0
    while iteration < MAX_ITERATIONS:
        result = await chain.ainvoke(state.messages)
        state.messages.append(convert_to_openai_messages(result))

        if not result.tool_calls:
            # Final response — no tools required
            setattr(state, f"{type}_report", result.content or "")
            break

        # Run all tool calls in parallel
        async def run_single_tool(tool_call):
            tool_name = tool_call["name"]
            tool_args = tool_call["args"]
            tool = getattr(toolkit, tool_name, None)

            if not tool:
                return None

            content = await tool(**tool_args)
            return ToolMessage(
                tool_call_id=tool_call["id"], name=tool_name, content=content
            )

        with flyte.group(f"tool_calls_iteration_{iteration}"):
            tool_messages = await asyncio.gather(
                *[run_single_tool(tc) for tc in result.tool_calls]
            )

        # Add valid tool results to state
        tool_messages = [msg for msg in tool_messages if msg]
        state.messages.extend(convert_to_openai_messages(tool_messages))

        iteration += 1
    else:
        # Reached iteration cap — optionally raise or log
        print(f"Max iterations ({MAX_ITERATIONS}) reached for {type}")

    return state

# {{/docs-fragment agent_helper}}

@env.task
async def create_fundamentals_analyst(
    llm: str, state: AgentState, online_tools: bool
) -> AgentState:
    if online_tools:
        tools = [toolkit.get_fundamentals_openai]
    else:
        tools = [
            toolkit.get_finnhub_company_insider_sentiment,
            toolkit.get_finnhub_company_insider_transactions,
            toolkit.get_simfin_balance_sheet,
            toolkit.get_simfin_cashflow,
            toolkit.get_simfin_income_stmt,
        ]

    system_message = (
        "You are a researcher tasked with analyzing fundamental information over the past week about a company. "
        "Please write a comprehensive report of the company's fundamental information such as financial documents, "
        "company profile, basic company financials, company financial history, insider sentiment, and insider "
        "transactions to gain a full view of the company's "
        "fundamental information to inform traders. Make sure to include as much detail as possible. "
        "Do not simply state the trends are mixed, "
        "provide detailed and finegrained analysis and insights that may help traders make decisions. "
        "Make sure to append a Markdown table at the end of the report to organize key points in the report, "
        "organized and easy to read."
    )

    tool_names = [tool.func.__name__ for tool in tools]

    return await run_chain_with_tools(
        "fundamentals", state, llm, system_message, tool_names
    )

@env.task
async def create_market_analyst(
    llm: str, state: AgentState, online_tools: bool
) -> AgentState:
    if online_tools:
        tools = [
            toolkit.get_YFin_data_online,
            toolkit.get_stockstats_indicators_report_online,
        ]
    else:
        tools = [
            toolkit.get_YFin_data,
            toolkit.get_stockstats_indicators_report,
        ]

    system_message = (
        """You are a trading assistant tasked with analyzing financial markets.
Your role is to select the **most relevant indicators** for a given market condition
or trading strategy from the following list.
The goal is to choose up to **8 indicators** that provide complementary insights without redundancy.
Categories and each category's indicators are:

Moving Averages:
- close_50_sma: 50 SMA: A medium-term trend indicator.
Usage: Identify trend direction and serve as dynamic support/resistance.
Tips: It lags price; combine with faster indicators for timely signals.
- close_200_sma: 200 SMA: A long-term trend benchmark.
Usage: Confirm overall market trend and identify golden/death cross setups.
Tips: It reacts slowly; best for strategic trend confirmation rather than frequent trading entries.
- close_10_ema: 10 EMA: A responsive short-term average.
Usage: Capture quick shifts in momentum and potential entry points.
Tips: Prone to noise in choppy markets; use alongside longer averages for filtering false signals.

MACD Related:
- macd: MACD: Computes momentum via differences of EMAs.
Usage: Look for crossovers and divergence as signals of trend changes.
Tips: Confirm with other indicators in low-volatility or sideways markets.
- macds: MACD Signal: An EMA smoothing of the MACD line.
Usage: Use crossovers with the MACD line to trigger trades.
Tips: Should be part of a broader strategy to avoid false positives.
- macdh: MACD Histogram: Shows the gap between the MACD line and its signal.
Usage: Visualize momentum strength and spot divergence early.
Tips: Can be volatile; complement with additional filters in fast-moving markets.

Momentum Indicators:
- rsi: RSI: Measures momentum to flag overbought/oversold conditions.
Usage: Apply 70/30 thresholds and watch for divergence to signal reversals.
Tips: In strong trends, RSI may remain extreme; always cross-check with trend analysis.

Volatility Indicators:
- boll: Bollinger Middle: A 20 SMA serving as the basis for Bollinger Bands.
Usage: Acts as a dynamic benchmark for price movement.
Tips: Combine with the upper and lower bands to effectively spot breakouts or reversals.
- boll_ub: Bollinger Upper Band: Typically 2 standard deviations above the middle line.
Usage: Signals potential overbought conditions and breakout zones.
Tips: Confirm signals with other tools; prices may ride the band in strong trends.
- boll_lb: Bollinger Lower Band: Typically 2 standard deviations below the middle line.
Usage: Indicates potential oversold conditions.
Tips: Use additional analysis to avoid false reversal signals.
- atr: ATR: Averages true range to measure volatility.
Usage: Set stop-loss levels and adjust position sizes based on current market volatility.
Tips: It's a reactive measure, so use it as part of a broader risk management strategy.

Volume-Based Indicators:
- vwma: VWMA: A moving average weighted by volume.
Usage: Confirm trends by integrating price action with volume data.
Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses.

- Select indicators that provide diverse and complementary information.
Avoid redundancy (e.g., do not select both rsi and stochrsi).
Also briefly explain why they are suitable for the given market context.
When you tool call, please use the exact name of the indicators provided above as they are defined parameters,
otherwise your call will fail.
Please make sure to call get_YFin_data first to retrieve the CSV that is needed to generate indicators.
Write a very detailed and nuanced report of the trends you observe.
Do not simply state the trends are mixed, provide detailed and fine-grained analysis
and insights that may help traders make decisions."""
        """ Make sure to append a Markdown table at the end of the report to
        organize key points in the report, organized and easy to read."""
    )

    tool_names = [tool.func.__name__ for tool in tools]
    return await run_chain_with_tools("market", state, llm, system_message, tool_names)

# {{docs-fragment news_analyst}}
@env.task
async def create_news_analyst(
    llm: str, state: AgentState, online_tools: bool
) -> AgentState:
    if online_tools:
        tools = [
            toolkit.get_global_news_openai,
            toolkit.get_google_news,
        ]
    else:
        tools = [
            toolkit.get_finnhub_news,
            toolkit.get_reddit_news,
            toolkit.get_google_news,
        ]

    system_message = (
        "You are a news researcher tasked with analyzing recent news and trends over the past week. "
        "Please write a comprehensive report of the current state of the world that is relevant for "
        "trading and macroeconomics. "
        "Look at news from EODHD, and finnhub to be comprehensive. Do not simply state the trends are mixed, "
        "provide detailed and fine-grained analysis and insights that may help traders make decisions."
        """ Make sure to append a Markdown table at the end of the report to organize key points in the report,
        organized and easy to read."""
    )

    tool_names = [tool.func.__name__ for tool in tools]

    return await run_chain_with_tools("news", state, llm, system_message, tool_names)

# {{/docs-fragment news_analyst}}

@env.task
async def create_social_media_analyst(
    llm: str, state: AgentState, online_tools: bool
) -> AgentState:
    if online_tools:
        tools = [toolkit.get_stock_news_openai]
    else:
        tools = [toolkit.get_reddit_stock_info]

    system_message = (
        "You are a social media and company specific news researcher/analyst tasked with analyzing social media posts, "
        "recent company news, and public sentiment for a specific company over the past week. "
        "You will be given a company's name; your objective is to write a comprehensive long report "
        "detailing your analysis, insights, and implications for traders and investors on this company's current state "
        "after looking at social media and what people are saying about that company, "
        "analyzing sentiment data of what people feel each day about the company, and looking at recent company news. "
        "Try to look at all sources possible from social media to sentiment to news. Do not simply state the trends "
        "are mixed, provide detailed and fine-grained analysis and insights that may help traders make decisions."
        """ Make sure to append a Markdown table at the end of the report to organize key points in the report,
        organized and easy to read."""
    )

    tool_names = [tool.func.__name__ for tool in tools]

    return await run_chain_with_tools(
        "sentiment", state, llm, system_message, tool_names
    )
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/agents/analysts.py*

Once all analyst reports are complete, their outputs are collected and passed to the next stage of the workflow.
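Each analyst task takes an `AgentState` and returns an updated one. Assuming the four analysts are independent of each other, one way to fan them out concurrently and merge their reports back into a single state is sketched below with plain `asyncio`. `ReportState`, the stub analyst coroutines, and `run_analysts` are hypothetical stand-ins with no LLM, tool, or Flyte calls; this is an illustration of the fan-out/fan-in pattern, not the example's actual orchestration code.

```python
import asyncio
import copy
from dataclasses import dataclass


@dataclass
class ReportState:
    # Hypothetical, trimmed-down stand-in for AgentState.
    company_of_interest: str
    market_report: str = ""
    sentiment_report: str = ""
    news_report: str = ""
    fundamentals_report: str = ""


# Stub analysts: each fills in exactly one report field on its own copy of the state.
async def market_analyst(state: ReportState) -> ReportState:
    await asyncio.sleep(0)  # stands in for tool calls and an LLM round-trip
    state.market_report = f"Market view for {state.company_of_interest}"
    return state


async def social_analyst(state: ReportState) -> ReportState:
    await asyncio.sleep(0)
    state.sentiment_report = f"Sentiment for {state.company_of_interest}"
    return state


async def news_analyst(state: ReportState) -> ReportState:
    await asyncio.sleep(0)
    state.news_report = "Macro news summary"
    return state


async def fundamentals_analyst(state: ReportState) -> ReportState:
    await asyncio.sleep(0)
    state.fundamentals_report = f"Fundamentals for {state.company_of_interest}"
    return state


REPORT_FIELDS = ("market_report", "sentiment_report", "news_report", "fundamentals_report")


async def run_analysts(state: ReportState) -> ReportState:
    analysts = [market_analyst, social_analyst, news_analyst, fundamentals_analyst]
    # Fan out: give every analyst an independent copy so they cannot overwrite each other.
    results = await asyncio.gather(*(analyst(copy.deepcopy(state)) for analyst in analysts))
    # Fan in: copy each non-empty report field back onto a single merged state.
    merged = copy.deepcopy(state)
    for result in results:
        for name in REPORT_FIELDS:
            value = getattr(result, name)
            if value:
                setattr(merged, name, value)
    return merged
```

For example, `asyncio.run(run_analysts(ReportState(company_of_interest="NVDA")))` returns one state with all four report fields populated, ready to hand to the research stage.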

### Research agents

The research phase consists of two agents: a bullish researcher and a bearish one. They evaluate the company from opposing viewpoints, drawing on the analysts' reports. Unlike analysts, they don't use tools. Their role is to interpret, critique, and develop positions based on the evidence.

```
from agents.utils.utils import AgentState, InvestmentDebateState, memory_init
from flyte_env import env
from langchain_openai import ChatOpenAI

# {{docs-fragment bear_researcher}}
@env.task
async def create_bear_researcher(llm: str, state: AgentState) -> AgentState:
    investment_debate_state = state.investment_debate_state
    history = investment_debate_state.history
    bear_history = investment_debate_state.bear_history

    current_response = investment_debate_state.current_response
    market_research_report = state.market_report
    sentiment_report = state.sentiment_report
    news_report = state.news_report
    fundamentals_report = state.fundamentals_report

    memory = await memory_init(name="bear-researcher")

    curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
    past_memories = memory.get_memories(curr_situation, n_matches=2)

    past_memory_str = ""
    for rec in past_memories:
        past_memory_str += rec["recommendation"] + "\n\n"

    prompt = f"""You are a Bear Analyst making the case against investing in the stock.
Your goal is to present a well-reasoned argument emphasizing risks, challenges, and negative indicators.
Leverage the provided research and data to highlight potential downsides and counter bullish arguments effectively.

Key points to focus on:

- Risks and Challenges: Highlight factors like market saturation, financial instability,
or macroeconomic threats that could hinder the stock's performance.
- Competitive Weaknesses: Emphasize vulnerabilities such as weaker market positioning, declining innovation,
or threats from competitors.
- Negative Indicators: Use evidence from financial data, market trends, or recent adverse news to support your position.
- Bull Counterpoints: Critically analyze the bull argument with specific data and sound reasoning,
exposing weaknesses or over-optimistic assumptions.
- Engagement: Present your argument in a conversational style, directly engaging with the bull analyst's points
and debating effectively rather than simply listing facts.

Resources available:

Market research report: {market_research_report}
Social media sentiment report: {sentiment_report}
Latest world affairs news: {news_report}
Company fundamentals report: {fundamentals_report}
Conversation history of the debate: {history}
Last bull argument: {current_response}
Reflections from similar situations and lessons learned: {past_memory_str}
Use this information to deliver a compelling bear argument, refute the bull's claims, and engage in a dynamic debate
that demonstrates the risks and weaknesses of investing in the stock.
You must also address reflections and learn from lessons and mistakes you made in the past.
"""

    response = ChatOpenAI(model=llm).invoke(prompt)

    argument = f"Bear Analyst: {response.content}"

    new_investment_debate_state = InvestmentDebateState(
        history=history + "\n" + argument,
        bear_history=bear_history + "\n" + argument,
        bull_history=investment_debate_state.bull_history,
        current_response=argument,
        count=investment_debate_state.count + 1,
    )

    state.investment_debate_state = new_investment_debate_state
    return state

# {{/docs-fragment bear_researcher}}

@env.task
async def create_bull_researcher(llm: str, state: AgentState) -> AgentState:
    investment_debate_state = state.investment_debate_state
    history = investment_debate_state.history
    bull_history = investment_debate_state.bull_history

    current_response = investment_debate_state.current_response
    market_research_report = state.market_report
    sentiment_report = state.sentiment_report
    news_report = state.news_report
    fundamentals_report = state.fundamentals_report

    memory = await memory_init(name="bull-researcher")

    curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
    past_memories = memory.get_memories(curr_situation, n_matches=2)

    past_memory_str = ""
    for rec in past_memories:
        past_memory_str += rec["recommendation"] + "\n\n"

    prompt = f"""You are a Bull Analyst advocating for investing in the stock.
Your task is to build a strong, evidence-based case emphasizing growth potential, competitive advantages,
and positive market indicators.
Leverage the provided research and data to address concerns and counter bearish arguments effectively.

Key points to focus on:
- Growth Potential: Highlight the company's market opportunities, revenue projections, and scalability.
- Competitive Advantages: Emphasize factors like unique products, strong branding, or dominant market positioning.
- Positive Indicators: Use financial health, industry trends, and recent positive news as evidence.
- Bear Counterpoints: Critically analyze the bear argument with specific data and sound reasoning, addressing
concerns thoroughly and showing why the bull perspective holds stronger merit.
- Engagement: Present your argument in a conversational style, engaging directly with the bear analyst's points
and debating effectively rather than just listing data.

Resources available:
Market research report: {market_research_report}
Social media sentiment report: {sentiment_report}
Latest world affairs news: {news_report}
Company fundamentals report: {fundamentals_report}
Conversation history of the debate: {history}
Last bear argument: {current_response}
Reflections from similar situations and lessons learned: {past_memory_str}
Use this information to deliver a compelling bull argument, refute the bear's concerns, and engage in a dynamic debate
that demonstrates the strengths of the bull position.
You must also address reflections and learn from lessons and mistakes you made in the past.
"""

    response = ChatOpenAI(model=llm).invoke(prompt)

    argument = f"Bull Analyst: {response.content}"

    new_investment_debate_state = InvestmentDebateState(
        history=history + "\n" + argument,
        bull_history=bull_history + "\n" + argument,
        bear_history=investment_debate_state.bear_history,
        current_response=argument,
        count=investment_debate_state.count + 1,
    )

    state.investment_debate_state = new_investment_debate_state
    return state
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/agents/researchers.py*

To aid reasoning, each researcher also retrieves the two most similar past reflections ("memories") from a vector database via `memory.get_memories(curr_situation, n_matches=2)`, giving it richer historical context. The number of debate rounds is configurable. Once the bull and bear have gone back and forth for the configured number of rounds, a research manager agent reviews the debate and commits to a Buy, Sell, or Hold recommendation, which becomes the investment plan passed to the trader.

```
from agents.utils.utils import (
    AgentState,
    InvestmentDebateState,
    RiskDebateState,
    memory_init,
)
from flyte_env import env
from langchain_openai import ChatOpenAI

# {{docs-fragment research_manager}}
@env.task
async def create_research_manager(llm: str, state: AgentState) -> AgentState:
    history = state.investment_debate_state.history
    investment_debate_state = state.investment_debate_state
    market_research_report = state.market_report
    sentiment_report = state.sentiment_report
    news_report = state.news_report
    fundamentals_report = state.fundamentals_report

    memory = await memory_init(name="research-manager")

    curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
    past_memories = memory.get_memories(curr_situation, n_matches=2)

    past_memory_str = ""
    for rec in past_memories:
        past_memory_str += rec["recommendation"] + "\n\n"

    prompt = f"""As the portfolio manager and debate facilitator, your role is to critically evaluate
this round of debate and make a definitive decision:
align with the bear analyst, the bull analyst,
or choose Hold only if it is strongly justified based on the arguments presented.

Summarize the key points from both sides concisely, focusing on the most compelling evidence or reasoning.
Your recommendation—Buy, Sell, or Hold—must be clear and actionable.
Avoid defaulting to Hold simply because both sides have valid points;
commit to a stance grounded in the debate's strongest arguments.

Additionally, develop a detailed investment plan for the trader. This should include:

Your Recommendation: A decisive stance supported by the most convincing arguments.
Rationale: An explanation of why these arguments lead to your conclusion.
Strategic Actions: Concrete steps for implementing the recommendation.
Take into account your past mistakes on similar situations.
Use these insights to refine your decision-making and ensure you are learning and improving.
Present your analysis conversationally, as if speaking naturally, without special formatting.

Here are your past reflections on mistakes:
\"{past_memory_str}\"

Here is the debate:
Debate History:
{history}"""
    response = ChatOpenAI(model=llm).invoke(prompt)

    new_investment_debate_state = InvestmentDebateState(
        judge_decision=response.content,
        history=investment_debate_state.history,
        bear_history=investment_debate_state.bear_history,
        bull_history=investment_debate_state.bull_history,
        current_response=response.content,
        count=investment_debate_state.count,
    )

    state.investment_debate_state = new_investment_debate_state
    state.investment_plan = response.content

    return state

# {{/docs-fragment research_manager}}

@env.task
async def create_risk_manager(llm: str, state: AgentState) -> AgentState:
    history = state.risk_debate_state.history
    risk_debate_state = state.risk_debate_state
    trader_plan = state.investment_plan
    market_research_report = state.market_report
    sentiment_report = state.sentiment_report
    news_report = state.news_report
    fundamentals_report = state.fundamentals_report

    memory = await memory_init(name="risk-manager")

    curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
    past_memories = memory.get_memories(curr_situation, n_matches=2)

    past_memory_str = ""
    for rec in past_memories:
        past_memory_str += rec["recommendation"] + "\n\n"

    prompt = f"""As the Risk Management Judge and Debate Facilitator,
your goal is to evaluate the debate between three risk analysts—Risky,
Neutral, and Safe/Conservative—and determine the best course of action for the trader.
Your decision must result in a clear recommendation: Buy, Sell, or Hold.
Choose Hold only if strongly justified by specific arguments, not as a fallback when all sides seem valid.
Strive for clarity and decisiveness.

Guidelines for Decision-Making:
1. **Summarize Key Arguments**: Extract the strongest points from each analyst, focusing on relevance to the context.
2. **Provide Rationale**: Support your recommendation with direct quotes and counterarguments from the debate.
3. **Refine the Trader's Plan**: Start with the trader's original plan, **{trader_plan}**,
and adjust it based on the analysts' insights.
4. **Learn from Past Mistakes**: Use lessons from **{past_memory_str}** to address prior misjudgments
and improve the decision you are making now to make sure you don't make a wrong BUY/SELL/HOLD call that loses money.

Deliverables:
- A clear and actionable recommendation: Buy, Sell, or Hold.
- Detailed reasoning anchored in the debate and past reflections.

---

**Analysts Debate History:**
{history}

---

Focus on actionable insights and continuous improvement.
Build on past lessons, critically evaluate all perspectives, and ensure each decision advances better outcomes."""

    response = ChatOpenAI(model=llm).invoke(prompt)

    new_risk_debate_state = RiskDebateState(
        judge_decision=response.content,
        history=risk_debate_state.history,
        risky_history=risk_debate_state.risky_history,
        safe_history=risk_debate_state.safe_history,
        neutral_history=risk_debate_state.neutral_history,
        latest_speaker="Judge",
        current_risky_response=risk_debate_state.current_risky_response,
        current_safe_response=risk_debate_state.current_safe_response,
        current_neutral_response=risk_debate_state.current_neutral_response,
        count=risk_debate_state.count,
    )

    state.risk_debate_state = new_risk_debate_state
    state.final_trade_decision = response.content

    return state
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/agents/managers.py*
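The bull/bear alternation, bounded by the configurable number of debate rounds, could be driven by a loop like the one below. This is a hypothetical sketch: `DebateState`, the stub `bull_turn`/`bear_turn`/`judge` functions, and `max_rounds` are illustrative names, the stubs replace the LLM calls, and starting with the bull is an assumption. It mirrors how each researcher in the code above appends its argument to `history` and increments `count`, so one round is one bull turn plus one bear turn.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class DebateState:
    # Hypothetical subset of the InvestmentDebateState fields.
    history: str = ""
    bull_history: str = ""
    bear_history: str = ""
    current_response: str = ""
    count: int = 0
    judge_decision: str = ""


def bull_turn(state: DebateState) -> DebateState:
    # Stub for create_bull_researcher: a real turn would call an LLM.
    argument = f"Bull Analyst: rebuttal #{state.count + 1}"
    return replace(
        state,
        history=state.history + "\n" + argument,
        bull_history=state.bull_history + "\n" + argument,
        current_response=argument,
        count=state.count + 1,
    )


def bear_turn(state: DebateState) -> DebateState:
    # Stub for create_bear_researcher.
    argument = f"Bear Analyst: rebuttal #{state.count + 1}"
    return replace(
        state,
        history=state.history + "\n" + argument,
        bear_history=state.bear_history + "\n" + argument,
        current_response=argument,
        count=state.count + 1,
    )


def judge(state: DebateState) -> DebateState:
    # Stub for create_research_manager: a real judge would read state.history.
    return replace(state, judge_decision="Buy", current_response="Buy")


def run_investment_debate(max_rounds: int) -> DebateState:
    state = DebateState()
    # Even counts are bull turns, odd counts are bear turns; stop after max_rounds pairs.
    while state.count < 2 * max_rounds:
        state = bull_turn(state) if state.count % 2 == 0 else bear_turn(state)
    return judge(state)
```

With `max_rounds=2`, the history holds four alternating arguments before the judge runs. Keeping the state immutable (`frozen=True` plus `dataclasses.replace`) makes each turn's input and output distinct values, which matches the way the tasks above build a new `InvestmentDebateState` rather than mutating the old one.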

### Trading agent

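The trader agent turns the research manager's investment plan, together with lessons retrieved from similar past trades, into a concrete trading proposal. It synthesizes competing signals and produces a conclusion such as _Buy for long-term growth despite short-term volatility_, always ending with an explicit BUY, HOLD, or SELL proposal that the risk team then debates.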

```
from agents.utils.utils import AgentState, memory_init
from flyte_env import env
from langchain_core.messages import convert_to_openai_messages
from langchain_openai import ChatOpenAI

# {{docs-fragment trader}}
@env.task
async def create_trader(llm: str, state: AgentState) -> AgentState:
    company_name = state.company_of_interest
    investment_plan = state.investment_plan
    market_research_report = state.market_report
    sentiment_report = state.sentiment_report
    news_report = state.news_report
    fundamentals_report = state.fundamentals_report

    memory = await memory_init(name="trader")

    curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
    past_memories = memory.get_memories(curr_situation, n_matches=2)

    past_memory_str = ""
    for rec in past_memories:
        past_memory_str += rec["recommendation"] + "\n\n"

    context = {
        "role": "user",
        "content": f"Based on a comprehensive analysis by a team of analysts, "
        f"here is an investment plan tailored for {company_name}. "
        "This plan incorporates insights from current technical market trends, "
        "macroeconomic indicators, and social media sentiment. "
        "Use this plan as a foundation for evaluating your next trading decision.\n\n"
        f"Proposed Investment Plan: {investment_plan}\n\n"
        "Leverage these insights to make an informed and strategic decision.",
    }

    messages = [
        {
            "role": "system",
            "content": f"""You are a trading agent analyzing market data to make investment decisions.
Based on your analysis, provide a specific recommendation to buy, sell, or hold.
End with a firm decision and always conclude your response with 'FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL**'
to confirm your recommendation.
Do not forget to utilize lessons from past decisions to learn from your mistakes.
Here are some reflections from similar situations you traded in and the lessons learned: {past_memory_str}""",
        },
        context,
    ]

    result = ChatOpenAI(model=llm).invoke(messages)

    state.messages.append(convert_to_openai_messages(result))
    state.trader_investment_plan = result.content
    state.sender = "Trader"

    return state

# {{/docs-fragment trader}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/agents/trader.py*
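Because the trader's system prompt requires every response to end with `FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL**`, downstream code can pull the decision out of `trader_investment_plan` with a small parser. `extract_final_proposal` is a hypothetical helper, not part of the example repository. It takes the last match so an earlier mention in the model's reasoning doesn't win, and returns `None` if the model ignored the instruction.

```python
import re
from typing import Optional

# Matches e.g. "FINAL TRANSACTION PROPOSAL: **BUY**" (bold markers optional, case-insensitive).
_PROPOSAL_RE = re.compile(
    r"FINAL TRANSACTION PROPOSAL:\s*\**\s*(BUY|HOLD|SELL)\s*\**",
    re.IGNORECASE,
)


def extract_final_proposal(text: str) -> Optional[str]:
    """Return 'BUY', 'HOLD', or 'SELL' from the last proposal line, or None if absent."""
    matches = _PROPOSAL_RE.findall(text)  # one capture group, so a list of strings
    return matches[-1].upper() if matches else None
```

Returning `None` instead of defaulting to HOLD lets the caller decide how to handle a malformed response, for example by retrying the trader task.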

### Risk agents

The risk team consists of three debaters with different risk tolerances: risky, neutral, and safe (conservative). They assess the trader's plan through lenses like market volatility, liquidity, and systemic risk. As in the bull-bear debate, they argue with one another over several turns, after which a risk manager makes the final call.

```
from agents.utils.utils import AgentState, RiskDebateState
from flyte_env import env
from langchain_openai import ChatOpenAI

# {{docs-fragment risk_debator}}
@env.task
async def create_risky_debator(llm: str, state: AgentState) -> AgentState:
    risk_debate_state = state.risk_debate_state
    history = risk_debate_state.history
    risky_history = risk_debate_state.risky_history

    current_safe_response = risk_debate_state.current_safe_response
    current_neutral_response = risk_debate_state.current_neutral_response

    market_research_report = state.market_report
    sentiment_report = state.sentiment_report
    news_report = state.news_report
    fundamentals_report = state.fundamentals_report

    trader_decision = state.trader_investment_plan

    prompt = f"""As the Risky Risk Analyst, your role is to actively champion high-reward, high-risk opportunities,
emphasizing bold strategies and competitive advantages.
When evaluating the trader's decision or plan, focus intently on the potential upside, growth potential,
and innovative benefits—even when these come with elevated risk.
Use the provided market data and sentiment analysis to strengthen your arguments and challenge the opposing views.
Specifically, respond directly to each point made by the conservative and neutral analysts,
countering with data-driven rebuttals and persuasive reasoning.
Highlight where their caution might miss critical opportunities or where their assumptions may be overly conservative.
Here is the trader's decision:

{trader_decision}

Your task is to create a compelling case for the trader's decision by questioning and critiquing the conservative
and neutral stances to demonstrate why your high-reward perspective offers the best path forward.
Incorporate insights from the following sources into your arguments:

Market Research Report: {market_research_report}
Social Media Sentiment Report: {sentiment_report}
Latest World Affairs Report: {news_report}
Company Fundamentals Report: {fundamentals_report}
Here is the current conversation history: {history}
Here are the last arguments from the conservative analyst: {current_safe_response}
Here are the last arguments from the neutral analyst: {current_neutral_response}.
If there are no responses from the other viewpoints, do not hallucinate and just present your point.

Engage actively by addressing any specific concerns raised, refuting the weaknesses in their logic,
and asserting the benefits of risk-taking to outpace market norms.
Maintain a focus on debating and persuading, not just presenting data.
Challenge each counterpoint to underscore why a high-risk approach is optimal.
Output conversationally as if you are speaking without any special formatting."""

    response = ChatOpenAI(model=llm).invoke(prompt)

    argument = f"Risky Analyst: {response.content}"

    new_risk_debate_state = RiskDebateState(
        history=history + "\n" + argument,
        risky_history=risky_history + "\n" + argument,
        safe_history=risk_debate_state.safe_history,
        neutral_history=risk_debate_state.neutral_history,
        latest_speaker="Risky",
        current_risky_response=argument,
        current_safe_response=current_safe_response,
        current_neutral_response=current_neutral_response,
        count=risk_debate_state.count + 1,
    )

    state.risk_debate_state = new_risk_debate_state
    return state

# {{/docs-fragment risk_debator}}

@env.task
async def create_safe_debator(llm: str, state: AgentState) -> AgentState:
    risk_debate_state = state.risk_debate_state
    history = risk_debate_state.history
    safe_history = risk_debate_state.safe_history

    current_risky_response = risk_debate_state.current_risky_response
    current_neutral_response = risk_debate_state.current_neutral_response

    market_research_report = state.market_report
    sentiment_report = state.sentiment_report
    news_report = state.news_report
    fundamentals_report = state.fundamentals_report

    trader_decision = state.trader_investment_plan

    prompt = f"""As the Safe/Conservative Risk Analyst, your primary objective is to protect assets,
minimize volatility, and ensure steady, reliable growth. You prioritize stability, security, and risk mitigation,
carefully assessing potential losses, economic downturns, and market volatility.
When evaluating the trader's decision or plan, critically examine high-risk elements,
pointing out where the decision may expose the firm to undue risk and where more cautious
alternatives could secure long-term gains.
Here is the trader's decision:

{trader_decision}

Your task is to actively counter the arguments of the Risky and Neutral Analysts,
highlighting where their views may overlook potential threats or fail to prioritize sustainability.
Respond directly to their points, drawing from the following data sources
to build a convincing case for a low-risk approach adjustment to the trader's decision:

Market Research Report: {market_research_report}
Social Media Sentiment Report: {sentiment_report}
Latest World Affairs Report: {news_report}
Company Fundamentals Report: {fundamentals_report}
Here is the current conversation history: {history}
Here is the last response from the risky analyst: {current_risky_response}
Here is the last response from the neutral analyst: {current_neutral_response}.
If there are no responses from the other viewpoints, do not hallucinate and just present your point.

Engage by questioning their optimism and emphasizing the potential downsides they may have overlooked.
Address each of their counterpoints to showcase why a conservative stance is ultimately the
safest path for the firm's assets.
Focus on debating and critiquing their arguments to demonstrate the strength of a low-risk strategy
over their approaches.
Output conversationally as if you are speaking without any special formatting."""

    response = ChatOpenAI(model=llm).invoke(prompt)

    argument = f"Safe Analyst: {response.content}"

    new_risk_debate_state = RiskDebateState(
        history=history + "\n" + argument,
        risky_history=risk_debate_state.risky_history,
        safe_history=safe_history + "\n" + argument,
        neutral_history=risk_debate_state.neutral_history,
        latest_speaker="Safe",
        current_risky_response=current_risky_response,
        current_safe_response=argument,
        current_neutral_response=current_neutral_response,
        count=risk_debate_state.count + 1,
    )

    state.risk_debate_state = new_risk_debate_state
    return state

@env.task
async def create_neutral_debator(llm: str, state: AgentState) -> AgentState:
    risk_debate_state = state.risk_debate_state
    history = risk_debate_state.history
    neutral_history = risk_debate_state.neutral_history

    current_risky_response = risk_debate_state.current_risky_response
    current_safe_response = risk_debate_state.current_safe_response

    market_research_report = state.market_report
    sentiment_report = state.sentiment_report
    news_report = state.news_report
    fundamentals_report = state.fundamentals_report

    trader_decision = state.trader_investment_plan

    prompt = f"""As the Neutral Risk Analyst, your role is to provide a balanced perspective,
weighing both the potential benefits and risks of the trader's decision or plan.
You prioritize a well-rounded approach, evaluating the upsides
and downsides while factoring in broader market trends,
potential economic shifts, and diversification strategies. Here is the trader's decision:

{trader_decision}

Your task is to challenge both the Risky and Safe Analysts,
pointing out where each perspective may be overly optimistic or overly cautious.
Use insights from the following data sources to support a moderate, sustainable strategy
to adjust the trader's decision:

Market Research Report: {market_research_report}
Social Media Sentiment Report: {sentiment_report}
Latest World Affairs Report: {news_report}
Company Fundamentals Report: {fundamentals_report}
Here is the current conversation history: {history}
Here is the last response from the risky analyst: {current_risky_response}
Here is the last response from the safe analyst: {current_safe_response}.
If there are no responses from the other viewpoints, do not hallucinate and just present your point.

Engage actively by analyzing both sides critically, addressing weaknesses in the risky
and conservative arguments to advocate for a more balanced approach.
Challenge each of their points to illustrate why a moderate risk strategy might offer the best of both worlds,
providing growth potential while safeguarding against extreme volatility.
Focus on debating rather than simply presenting data, aiming to show that a balanced view can lead to
the most reliable outcomes. Output conversationally as if you are speaking without any special formatting."""

    response = ChatOpenAI(model=llm).invoke(prompt)

    argument = f"Neutral Analyst: {response.content}"

    new_risk_debate_state = RiskDebateState(
        history=history + "\n" + argument,
        risky_history=risk_debate_state.risky_history,
        safe_history=risk_debate_state.safe_history,
        neutral_history=neutral_history + "\n" + argument,
        latest_speaker="Neutral",
        current_risky_response=current_risky_response,
        current_safe_response=current_safe_response,
        current_neutral_response=argument,
        count=risk_debate_state.count + 1,
    )

    state.risk_debate_state = new_risk_debate_state
    return state
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/agents/risk_debators.py*

The outcome of the risk manager — whether to proceed with the trade or not — is considered the final decision of the trading simulation.
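In `main.py` (shown in full below), the risk manager's free-text `final_trade_decision` is passed to `process_signal`, which prompts an LLM to reply with only SELL, BUY, or HOLD. Models occasionally add punctuation, formatting, or extra words anyway. The sketch below is a hypothetical post-processing helper (`normalize_decision` is not part of the tutorial code) that reduces such a reply to one of the three labels. Falling back to HOLD when no label appears is an assumption; you may prefer to raise an error instead.

```python
import re

# Matches BUY, SELL, or HOLD as a standalone word, in any letter case.
_DECISION_RE = re.compile(r"\b(BUY|SELL|HOLD)\b", re.IGNORECASE)


def normalize_decision(raw: str, default: str = "HOLD") -> str:
    """Reduce a free-form LLM reply to exactly one of BUY, SELL, or HOLD.

    Hypothetical helper: returns the first standalone decision word found,
    or `default` when the reply contains none.
    """
    match = _DECISION_RE.search(raw)
    return match.group(1).upper() if match else default


print(normalize_decision("SELL"))                      # SELL
print(normalize_decision("Final decision: **buy**."))  # BUY
print(normalize_decision("No clear signal"))           # HOLD (the default)
```

Because the helper returns the first match, a reply that mentions several labels (for example, "not a BUY, rather a SELL") maps to the first one, BUY. If that matters for your use case, check for more than one distinct label and raise instead.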

You can visualize this full pipeline in the Flyte/Union UI, where every step is logged.
You’ll see input/output metadata for each tool and agent task.
Thanks to Flyte's caching, repeated steps are skipped unless inputs change, saving time and compute resources.

### Retaining agent memory with S3 vectors

To help agents learn from past decisions, we persist their memory in a vector store. In this example, we use [S3 Vectors](https://aws.amazon.com/s3/features/vectors/) buckets for their simplicity and tight integration with Flyte and Union, but any vector database can be used.

Note: To use the S3 vector store, make sure your IAM role has the following permissions configured:

```
s3vectors:CreateVectorBucket
s3vectors:CreateIndex
s3vectors:PutVectors
s3vectors:GetIndex
s3vectors:GetVectors
s3vectors:QueryVectors
s3vectors:GetVectorBucket
```
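For reference, the permissions above can be expressed as an IAM identity policy. This sketch builds the policy document in Python and prints it as JSON. The `Sid` value and the wildcard `Resource` are placeholders chosen for illustration; in practice you would scope `Resource` to your own vector bucket and index ARNs.

```python
import json

# The S3 Vectors actions the tutorial lists for the agent-memory store.
S3VECTORS_ACTIONS = [
    "s3vectors:CreateVectorBucket",
    "s3vectors:CreateIndex",
    "s3vectors:PutVectors",
    "s3vectors:GetIndex",
    "s3vectors:GetVectors",
    "s3vectors:QueryVectors",
    "s3vectors:GetVectorBucket",
]


def build_memory_policy(resource: str = "*") -> dict:
    """Return an IAM identity-policy document allowing the actions above.

    resource="*" is a placeholder for brevity; narrow it to your vector
    bucket and index ARNs before attaching the policy to a real role.
    """
    return {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Sid": "AgentMemoryS3Vectors",  # placeholder statement ID
                "Effect": "Allow",
                "Action": S3VECTORS_ACTIONS,
                "Resource": resource,
            }
        ],
    }


print(json.dumps(build_memory_policy(), indent=2))
```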

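After a trade decision has played out, you can run the `reflect_on_decisions` task with the realized returns (for example, `"+3.2% gain over 5 days"`). It calls `main` to obtain the final agent state, then has the bull and bear researchers, the trader, the research manager, and the risk manager each reflect on whether their recommendation matched the outcome, storing those reflections in the vector store. These stored insights can later be retrieved to provide historical context and improve future decisions.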
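The store-then-retrieve pattern can be sketched without AWS. The example below is a simplified, in-memory stand-in, not the tutorial's `reflection` module: `ReflectionMemory` and `embed` are hypothetical names, and a toy bag-of-words vector with cosine similarity replaces both the embedding model and the S3 vector index. Each reflection is stored under a description of the market situation it came from, and a new situation retrieves the most similar past reflections.

```python
import math
import re
from collections import Counter
from dataclasses import dataclass, field


def embed(text: str) -> Counter:
    """Toy bag-of-words vector (lowercase word counts).

    Stand-in for a real embedding model; an assumption for illustration only.
    """
    return Counter(re.findall(r"[a-z]+", text.lower()))


def cosine(a: Counter, b: Counter) -> float:
    """Cosine similarity between two sparse count vectors."""
    dot = sum(count * b[token] for token, count in a.items())
    norm_a = math.sqrt(sum(c * c for c in a.values()))
    norm_b = math.sqrt(sum(c * c for c in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


@dataclass
class ReflectionMemory:
    """Hypothetical in-memory stand-in for one agent's vector index."""

    entries: list = field(default_factory=list)  # (situation vector, reflection) pairs

    def add(self, situation: str, reflection: str) -> None:
        """Store a reflection keyed by the market situation it came from."""
        self.entries.append((embed(situation), reflection))

    def query(self, situation: str, top_k: int = 1) -> list[str]:
        """Return the top_k reflections whose situations are most similar."""
        query_vec = embed(situation)
        ranked = sorted(
            self.entries, key=lambda entry: cosine(query_vec, entry[0]), reverse=True
        )
        return [reflection for _, reflection in ranked[:top_k]]


memory = ReflectionMemory()
memory.add("NVDA strong earnings momentum, rich valuation", "BUY call paid off: +3.2% over 5 days")
memory.add("Regional banks under credit stress as rates rise", "HOLD was too cautious")
print(memory.query("NVDA earnings momentum"))  # ['BUY call paid off: +3.2% over 5 days']
```

In the real setup, `add` corresponds to writing a vector (the `PutVectors` action listed above) and `query` to a nearest-neighbor search (`QueryVectors`).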

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#     "flyte>=2.0.0b52",
#     "akshare==1.16.98",
#     "backtrader==1.9.78.123",
#     "boto3==1.39.9",
#     "chainlit==2.5.5",
#     "eodhd==1.0.32",
#     "feedparser==6.0.11",
#     "finnhub-python==2.4.23",
#     "langchain-experimental==0.3.4",
#     "langchain-openai==0.3.23",
#     "pandas==2.3.0",
#     "parsel==1.10.0",
#     "praw==7.8.1",
#     "pytz==2025.2",
#     "questionary==2.1.0",
#     "redis==6.2.0",
#     "requests==2.32.4",
#     "stockstats==0.6.5",
#     "tqdm==4.67.1",
#     "tushare==1.4.21",
#     "typing-extensions==4.14.0",
#     "yfinance==0.2.63",
# ]
# main = "main"
# params = ""
# ///
import asyncio
from copy import deepcopy

import agents
import agents.analysts
from agents.managers import create_research_manager, create_risk_manager
from agents.researchers import create_bear_researcher, create_bull_researcher
from agents.risk_debators import (
    create_neutral_debator,
    create_risky_debator,
    create_safe_debator,
)
from agents.trader import create_trader
from agents.utils.utils import AgentState
from flyte_env import DEEP_THINKING_LLM, QUICK_THINKING_LLM, env, flyte
from langchain_openai import ChatOpenAI
from reflection import (
    reflect_bear_researcher,
    reflect_bull_researcher,
    reflect_research_manager,
    reflect_risk_manager,
    reflect_trader,
)

@env.task
async def process_signal(full_signal: str, QUICK_THINKING_LLM: str) -> str:
    """Process a full trading signal to extract the core decision."""

    messages = [
        {
            "role": "system",
            "content": """You are an efficient assistant designed to analyze paragraphs or
financial reports provided by a group of analysts.
Your task is to extract the investment decision: SELL, BUY, or HOLD.
Provide only the extracted decision (SELL, BUY, or HOLD) as your output,
without adding any additional text or information.""",
        },
        {"role": "human", "content": full_signal},
    ]

    return ChatOpenAI(model=QUICK_THINKING_LLM).invoke(messages).content

async def run_analyst(analyst_name, state, online_tools):
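    # Look up the analyst factory by name (e.g. create_market_analyst).
    # The caller passes in a deepcopy of the state, so analysts run in isolation.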
    run_fn = getattr(agents.analysts, f"create_{analyst_name}_analyst")

    # Run the analyst's chain
    result_state = await run_fn(QUICK_THINKING_LLM, state, online_tools)

    # Determine the report key
    report_key = (
        "sentiment_report"
        if analyst_name == "social_media"
        else f"{analyst_name}_report"
    )
    report_value = getattr(result_state, report_key)

    return result_state.messages[1:], report_key, report_value

# {{docs-fragment main}}
@env.task
async def main(
    selected_analysts: list[str] = [
        "market",
        "fundamentals",
        "news",
        "social_media",
    ],
    max_debate_rounds: int = 1,
    max_risk_discuss_rounds: int = 1,
    online_tools: bool = True,
    company_name: str = "NVDA",
    trade_date: str = "2024-05-12",
) -> tuple[str, AgentState]:
    if not selected_analysts:
        raise ValueError(
            "No analysts selected. Please select at least one analyst from market, fundamentals, news, or social_media."
        )

    state = AgentState(
        messages=[{"role": "human", "content": company_name}],
        company_of_interest=company_name,
        trade_date=str(trade_date),
    )

    # Run all analysts concurrently
    results = await asyncio.gather(
        *[
            run_analyst(analyst, deepcopy(state), online_tools)
            for analyst in selected_analysts
        ]
    )

    # Flatten and append all resulting messages into the shared state
    for messages, report_attr, report in results:
        state.messages.extend(messages)
        setattr(state, report_attr, report)

    # Bull/Bear debate loop
    state = await create_bull_researcher(QUICK_THINKING_LLM, state)  # Start with bull
    while state.investment_debate_state.count < 2 * max_debate_rounds:
        current = state.investment_debate_state.current_response
        if current.startswith("Bull"):
            state = await create_bear_researcher(QUICK_THINKING_LLM, state)
        else:
            state = await create_bull_researcher(QUICK_THINKING_LLM, state)

    state = await create_research_manager(DEEP_THINKING_LLM, state)
    state = await create_trader(QUICK_THINKING_LLM, state)

    # Risk debate loop
    state = await create_risky_debator(QUICK_THINKING_LLM, state)  # Start with risky
    while state.risk_debate_state.count < 3 * max_risk_discuss_rounds:
        speaker = state.risk_debate_state.latest_speaker
        if speaker == "Risky":
            state = await create_safe_debator(QUICK_THINKING_LLM, state)
        elif speaker == "Safe":
            state = await create_neutral_debator(QUICK_THINKING_LLM, state)
        else:
            state = await create_risky_debator(QUICK_THINKING_LLM, state)

    state = await create_risk_manager(DEEP_THINKING_LLM, state)
    decision = await process_signal(state.final_trade_decision, QUICK_THINKING_LLM)

    return decision, state

# {{/docs-fragment main}}

# {{docs-fragment reflect_on_decisions}}
@env.task
async def reflect_and_store(state: AgentState, returns: str) -> str:
    await asyncio.gather(
        reflect_bear_researcher(state, returns),
        reflect_bull_researcher(state, returns),
        reflect_trader(state, returns),
        reflect_risk_manager(state, returns),
        reflect_research_manager(state, returns),
    )

    return "Reflection completed."

# Run the reflection task after the main function
@env.task(cache="disable")
async def reflect_on_decisions(
    returns: str,
    selected_analysts: list[str] = [
        "market",
        "fundamentals",
        "news",
        "social_media",
    ],
    max_debate_rounds: int = 1,
    max_risk_discuss_rounds: int = 1,
    online_tools: bool = True,
    company_name: str = "NVDA",
    trade_date: str = "2024-05-12",
) -> str:
    _, state = await main(
        selected_analysts,
        max_debate_rounds,
        max_risk_discuss_rounds,
        online_tools,
        company_name,
        trade_date,
    )

    return await reflect_and_store(state, returns)

# {{/docs-fragment reflect_on_decisions}}

# {{docs-fragment execute_main}}
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
    run.wait()

    # run = flyte.run(reflect_on_decisions, "+3.2% gain over 5 days")
    # print(run.url)

# {{/docs-fragment execute_main}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/main.py*
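The two debate loops in `main` stop on a turn counter rather than after a fixed number of iterations. This sketch replays only the speaker order, with no LLM calls, to show how `max_debate_rounds` and `max_risk_discuss_rounds` translate into turns. It assumes that every researcher and debator call increments its debate's `count` by one, as the neutral debator shown earlier does, and that bull-researcher responses start with "Bull", which the loop's `startswith("Bull")` check relies on. The function names are hypothetical.

```python
def bull_bear_turns(max_debate_rounds: int) -> list[str]:
    """Replay the speaker order of main()'s bull/bear loop."""
    turns = ["Bull"]  # main() always opens with the bull researcher
    count = 1  # assumption: each researcher call increments the debate count by 1
    while count < 2 * max_debate_rounds:
        # Same branch as main(): after a Bull turn comes Bear, otherwise Bull.
        turns.append("Bear" if turns[-1] == "Bull" else "Bull")
        count += 1
    return turns


def risk_turns(max_risk_discuss_rounds: int) -> list[str]:
    """Replay the speaker order of main()'s risk-debate loop."""
    # Same branches as main(): Risky -> Safe -> Neutral -> Risky ...
    next_speaker = {"Risky": "Safe", "Safe": "Neutral", "Neutral": "Risky"}
    turns = ["Risky"]  # main() always opens with the risky debator
    count = 1  # each debator call increments risk_debate_state.count by 1
    while count < 3 * max_risk_discuss_rounds:
        turns.append(next_speaker[turns[-1]])
        count += 1
    return turns


for rounds in (0, 1, 2):
    print(rounds, bull_bear_turns(rounds), risk_turns(rounds))
```

With one round of each, the defaults produce a Bull, Bear exchange followed by Risky, Safe, Neutral. The opening Bull and Risky turns run before either loop condition is checked, so even a value of 0 yields one turn per debate.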

### Running the simulation

First, create secrets for your OpenAI API key (from [openai.com](https://platform.openai.com/api-keys)) and your Finnhub API key (from [finnhub.io](https://finnhub.io/)):

```
flyte create secret openai_api_key <YOUR_OPENAI_API_KEY>
flyte create secret finnhub_api_key <YOUR_FINNHUB_API_KEY>
```

Then [clone the repo](https://github.com/unionai/unionai-examples), navigate to the `v2/tutorials/trading_agents` directory, and run the following commands:

```
flyte create config --endpoint <FLYTE_OR_UNION_ENDPOINT> --project <PROJECT_NAME> --domain <DOMAIN_NAME> --builder remote
uv run main.py
```

If you'd like to run the `reflect_on_decisions` task instead, edit the `__main__` block at the bottom of `main.py`: comment out the `flyte.run(main)` call and uncomment the `flyte.run(reflect_on_decisions, ...)` call. The string argument is the realized outcome of the trade that the agents reflect on:

```
if __name__ == "__main__":
    flyte.init_from_config()
    # run = flyte.run(main)
    # print(run.url)
    # run.wait()

    run = flyte.run(reflect_on_decisions, "+3.2% gain over 5 days")
    print(run.url)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/main.py*

Then run:

```
uv run main.py
```

## Why Flyte? _(A quick note before you go)_

You might now be wondering: can't I just build all this with Python and LangChain?
Absolutely. But as your project grows, you'll likely run into these challenges:

1.  **Observability**: Agent workflows can feel opaque. You send a prompt, get a response, but what happened in between?

    - Were the right tools used?
    - Were correct arguments passed?
    - How did the LLM reason through intermediate steps?
    - Why did it fail?

    Flyte gives you a window into each of these stages.

2.  **Multi-agent coordination**: Real-world applications often require multiple agents with distinct roles and responsibilities. In such cases, you'll need:

    - Isolated state per agent,
    - Shared context where needed,
    - And coordination — sequential or parallel.

    Managing this manually gets fragile, fast. Flyte handles it for you.

3.  **Scalability**: Agents and tools might need to run in isolated or containerized environments. Whether you're scaling out to more agents or more powerful hardware, Flyte lets you scale without taxing your local machine or racking up unnecessary cloud bills.

4.  **Durability & recovery**: LLM-based workflows are often long-running and expensive. If something fails halfway:

    - Do you lose all progress?
    - Replay everything from scratch?

    With Flyte, you get built-in caching, checkpointing, and recovery, so you can resume where you left off.

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/financial-services/fraud-detection-feast ===

# Fraud detection with Feast

> [!NOTE]
> Code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/fraud_detection_feast).

This tutorial builds a credit-card fraud detection pipeline that combines [Feast](https://feast.dev/) feature store materialization with an XGBoost classifier on the Sparkov simulated transactions dataset. The workflow engineers transaction-level and user-level features, registers the user features in Feast and materializes them to an online store for low-latency scoring, and then trains a model on features retrieved through Feast.
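Training pulls user features through Feast's `get_historical_features`, which performs a point-in-time join: each transaction receives the most recent feature values whose timestamp is not later than the transaction's own. That is why `prepare_data` stamps each user profile with the user's earliest transaction time. The following is a simplified, pure-Python model of that lookup for illustration only (`point_in_time_lookup` is a hypothetical helper, not Feast's implementation), treating `ttl=None` as no expiry.

```python
from datetime import datetime, timedelta, timezone
from typing import Optional


def point_in_time_lookup(
    feature_rows: list,
    user_id: int,
    event_ts: datetime,
    ttl: Optional[timedelta] = None,
) -> Optional[dict]:
    """Return the newest feature row for user_id that is not later than event_ts.

    If ttl is given, rows older than event_ts - ttl are also skipped;
    ttl=None stands in for "no expiry".
    """
    candidates = [
        row
        for row in feature_rows
        if row["user_id"] == user_id
        and row["event_timestamp"] <= event_ts
        and (ttl is None or event_ts - row["event_timestamp"] <= ttl)
    ]
    if not candidates:
        return None  # Feast would leave this entity row's features null
    return max(candidates, key=lambda row: row["event_timestamp"])


utc = timezone.utc
txn_time = datetime(2019, 6, 1, tzinfo=utc)
stamped_early = [{"user_id": 7, "event_timestamp": datetime(2019, 1, 1, tzinfo=utc), "mean_amt": 52.3}]
stamped_late = [{"user_id": 7, "event_timestamp": datetime(2020, 6, 1, tzinfo=utc), "mean_amt": 52.3}]

print(point_in_time_lookup(stamped_early, 7, txn_time))  # the profile row
print(point_in_time_lookup(stamped_late, 7, txn_time))   # None
```

If the profiles were stamped with the latest transaction time instead, every earlier transaction would find no match; in `train_model` those rows would carry null user features and be removed by `dropna`.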

Flyte provides:

- **Cached data preparation** for the Kaggle dataset download and feature engineering.
- **Report-backed training** with confusion matrix and ROC-style metrics in the UI.
- **Durable artifacts** — the trained model and Feast repo are returned as `flyte.io.File` and `flyte.io.Dir`.
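Per the feature-definition comments in the code below, the derived features are computed both at training time and at scoring time. As a rough sketch of the scoring side, the hypothetical `build_feature_row` helper below combines a request's transaction fields with a user profile (in this setup that profile would come from the SQLite online store, for example via Feast's `get_online_features`) into one row in `ALL_FEATURE_COLS` order. It uses the same zero guards as `train_model` and the same UTC hour and day-of-week conventions as `prepare_data`. The request shape (`amt`, `category`, `merch_lat`, `merch_long`) is an assumption.

```python
import math
from datetime import datetime, timezone

# Column order mirrors ALL_FEATURE_COLS in the pipeline code below.
TXN_FEATURE_COLS = ["amt", "amt_log", "category_encoded", "merch_lat", "merch_long"]
USER_FEATURE_COLS = ["txn_count", "mean_amt", "std_amt", "max_amt", "home_lat", "home_long", "age"]
DERIVED_FEATURE_COLS = ["amt_zscore", "amt_ratio", "distance_from_home", "hour", "day_of_week"]
ALL_FEATURE_COLS = TXN_FEATURE_COLS + USER_FEATURE_COLS + DERIVED_FEATURE_COLS


def haversine_miles(lat1: float, lon1: float, lat2: float, lon2: float) -> float:
    """Scalar version of the pipeline's haversine(): great-circle distance in miles."""
    lat1, lon1, lat2, lon2 = map(math.radians, (lat1, lon1, lat2, lon2))
    a = (math.sin((lat2 - lat1) / 2) ** 2
         + math.cos(lat1) * math.cos(lat2) * math.sin((lon2 - lon1) / 2) ** 2)
    return 2 * 3959 * math.asin(math.sqrt(a))


def build_feature_row(txn: dict, profile: dict, category_mapping: dict, when: datetime) -> list:
    """Assemble one model input row in ALL_FEATURE_COLS order.

    txn: request fields (amt, category, merch_lat, merch_long); hypothetical shape.
    profile: the seven user_stats values, e.g. as read from the online store.
    when: transaction time in UTC, matching the UTC timestamps used in training.
    """
    amt = txn["amt"]
    std = profile["std_amt"] if profile["std_amt"] != 0 else 1  # same guard as .replace(0, 1)
    mean = profile["mean_amt"] if profile["mean_amt"] != 0 else 1
    values = {
        "amt": amt,
        "amt_log": math.log1p(amt),
        "category_encoded": category_mapping[txn["category"]],
        "merch_lat": txn["merch_lat"],
        "merch_long": txn["merch_long"],
        **{col: profile[col] for col in USER_FEATURE_COLS},
        "amt_zscore": (amt - profile["mean_amt"]) / std,
        "amt_ratio": amt / mean,
        "distance_from_home": haversine_miles(
            profile["home_lat"], profile["home_long"], txn["merch_lat"], txn["merch_long"]
        ),
        "hour": when.hour,
        "day_of_week": when.weekday(),  # Monday=0, like pandas dt.dayofweek
    }
    return [values[col] for col in ALL_FEATURE_COLS]


profile = {"txn_count": 120, "mean_amt": 50.0, "std_amt": 25.0, "max_amt": 400.0,
           "home_lat": 40.7, "home_long": -74.0, "age": 35}
txn = {"amt": 100.0, "category": "grocery_pos", "merch_lat": 40.7, "merch_long": -74.0}
row = build_feature_row(txn, profile, {"grocery_pos": 4},
                        datetime(2024, 1, 1, 14, 30, tzinfo=timezone.utc))
print(dict(zip(ALL_FEATURE_COLS, row))["amt_zscore"])  # 2.0
```

A row built this way could be passed as `[row]` to the trained `XGBClassifier`'s `predict_proba`. Keeping the column order identical to training matters most here, because the model was fit on a bare array rather than on named columns.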

## Define the task environment

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.4.0",
#    "feast==0.63.0",
#    "scikit-learn==1.8.0",
#    "xgboost==3.2.0",
#    "joblib",
#    "pandas",
#    "pyarrow",
#    "kagglehub==0.3.12",
# ]
# main = "fraud_detection_pipeline"
# params = ""
# ///
import json
import logging
import math
import os
import shutil
import tempfile
from datetime import datetime, timedelta, timezone

import joblib
import numpy as np
import pandas as pd
import flyte
import flyte.io
import flyte.report

# {{docs-fragment env}}
main_img = flyte.Image.from_uv_script(__file__, name="fraud-detection-feast", pre=True)

env = flyte.TaskEnvironment(
    name="fraud-detection-feast",
    image=main_img,
    resources=flyte.Resources(cpu=2, memory="4Gi"),
)
# {{/docs-fragment env}}

import report_helpers as rh

logging.basicConfig(level=logging.WARNING, format="%(message)s", force=True)
log = logging.getLogger(__name__)
log.setLevel(logging.INFO)

# ------------------------------------------------------------------
# Feature definitions
#
# Transaction features: known at scoring time (from the request)
# User features: pre-computed aggregates stored in Feast
# Derived features: computed at both training and scoring time by
#                   comparing the transaction to the user's profile
# ------------------------------------------------------------------

TXN_FEATURE_COLS = ["amt", "amt_log", "category_encoded", "merch_lat", "merch_long"]

USER_FEATURE_COLS = [
    "txn_count", "mean_amt", "std_amt", "max_amt",
    "home_lat", "home_long", "age",
]

DERIVED_FEATURE_COLS = [
    "amt_zscore", "amt_ratio", "distance_from_home", "hour", "day_of_week",
]

ALL_FEATURE_COLS = TXN_FEATURE_COLS + USER_FEATURE_COLS + DERIVED_FEATURE_COLS

def haversine(lat1, lon1, lat2, lon2):
    """Compute distance in miles between two (lat, lon) points."""
    R = 3959  # Earth radius in miles
    lat1, lon1, lat2, lon2 = map(np.radians, [lat1, lon1, lat2, lon2])
    dlat = lat2 - lat1
    dlon = lon2 - lon1
    a = np.sin(dlat / 2) ** 2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon / 2) ** 2
    return 2 * R * np.arcsin(np.sqrt(a))

# ------------------------------------------------------------------
# Task 1: Download dataset and engineer features
# ------------------------------------------------------------------

@env.task(report=True, cache="auto")
async def prepare_data() -> flyte.io.Dir:
    """Download the Sparkov credit card fraud dataset and prepare parquets."""
    import kagglehub

    log.info("Downloading dataset...")
    dataset_path = kagglehub.dataset_download("kartik2112/fraud-detection")
    csv_path = os.path.join(dataset_path, "fraudTrain.csv")
    df = pd.read_csv(csv_path)
    log.info(f"Loaded {len(df):,} transactions ({int(df['is_fraud'].sum()):,} fraudulent)")

    # Sample for workshop speed (stratified to preserve fraud ratio)
    if len(df) > 500_000:
        from sklearn.model_selection import train_test_split
        df, _ = train_test_split(df, train_size=500_000, stratify=df["is_fraud"], random_state=42)
        log.info(f"Sampled to {len(df):,} transactions")

    # ------------------------------------------------------------------
    # Parse timestamps
    # ------------------------------------------------------------------
    df["event_timestamp"] = pd.to_datetime(df["trans_date_trans_time"])
    df["event_timestamp"] = df["event_timestamp"].dt.tz_localize("UTC")
    df["hour"] = df["event_timestamp"].dt.hour
    df["day_of_week"] = df["event_timestamp"].dt.dayofweek

    # ------------------------------------------------------------------
    # Map cc_num → sequential user_id for clean API
    # ------------------------------------------------------------------
    cc_nums = df["cc_num"].unique()
    cc_to_user = {cc: i for i, cc in enumerate(sorted(cc_nums))}
    df["user_id"] = df["cc_num"].map(cc_to_user)

    # ------------------------------------------------------------------
    # Feature engineering
    # ------------------------------------------------------------------
    df["amt_log"] = np.log1p(df["amt"])

    # Label-encode merchant category
    categories = sorted(df["category"].unique())
    cat_to_int = {cat: i for i, cat in enumerate(categories)}
    df["category_encoded"] = df["category"].map(cat_to_int)

    # Compute age from dob
    df["dob"] = pd.to_datetime(df["dob"]).dt.tz_localize("UTC")
    ref_date = df["event_timestamp"].max()
    df["age"] = ((ref_date - df["dob"]).dt.days / 365.25).astype(int)

    # Distance between buyer and merchant
    df["distance"] = haversine(df["lat"], df["long"], df["merch_lat"], df["merch_long"])

    # ------------------------------------------------------------------
    # Build user aggregates
    # ------------------------------------------------------------------
    user_stats = df.groupby("user_id").agg(
        txn_count=("amt", "count"),
        mean_amt=("amt", "mean"),
        std_amt=("amt", "std"),
        max_amt=("amt", "max"),
        home_lat=("lat", "median"),
        home_long=("long", "median"),
        age=("age", "first"),
    ).reset_index()
    user_stats["std_amt"] = user_stats["std_amt"].fillna(0)
    # Use earliest timestamp so Feast point-in-time joins work for all transactions
    earliest_ts = df.groupby("user_id")["event_timestamp"].min().reset_index()
    user_stats = user_stats.merge(earliest_ts, on="user_id")

    # ------------------------------------------------------------------
    # Save to temp directory
    # ------------------------------------------------------------------
    data_dir = tempfile.mkdtemp()

    txn_cols = [
        "user_id", "event_timestamp",
        "amt", "amt_log", "category_encoded", "merch_lat", "merch_long",
        "hour", "day_of_week", "lat", "long", "distance",
        "is_fraud",
    ]
    df[txn_cols].to_parquet(os.path.join(data_dir, "transactions.parquet"), index=False)
    user_stats.to_parquet(os.path.join(data_dir, "user_features.parquet"), index=False)

    # Save category mapping + cc_num mapping for the app
    with open(os.path.join(data_dir, "category_mapping.json"), "w") as f:
        json.dump(cat_to_int, f)
    with open(os.path.join(data_dir, "user_mapping.json"), "w") as f:
        json.dump({str(k): v for k, v in cc_to_user.items()}, f)

    n_fraud = int(df["is_fraud"].sum())
    n_legit = len(df) - n_fraud
    fraud_pct = df["is_fraud"].mean() * 100
    html = (
        '<h2>Data Prepared</h2>'
        + rh.stat_grid([
            (f"{len(df):,}", "Transactions"),
            (f"{n_fraud:,}", "Fraudulent"),
            (f"{fraud_pct:.2f}%", "Fraud Rate"),
            (f"{user_stats['user_id'].nunique():,}", "Users"),
            (f"{len(categories)}", "Categories"),
        ])
        + rh.class_distribution_bar(n_legit, n_fraud)
    )
    await flyte.report.replace.aio(rh.wrap(html))
    await flyte.report.flush.aio()

    return await flyte.io.Dir.from_local(data_dir)

# ------------------------------------------------------------------
# Task 2: Set up Feast and materialize user profiles to online store
# ------------------------------------------------------------------

@env.task(report=True)
async def materialize_features(data_dir: flyte.io.Dir) -> flyte.io.Dir:
    """Apply Feast definitions and materialize user profiles to SQLite online store."""
    from feast import Entity, FeatureStore, FeatureView, Field, FileSource
    from feast.types import Float64, Int64

    data_path = await data_dir.download()

    # Create a self-contained Feast repo in a temp directory
    feast_dir = tempfile.mkdtemp()

    # Copy parquet into feast dir so the repo is fully self-contained
    shutil.copy2(
        os.path.join(data_path, "user_features.parquet"),
        os.path.join(feast_dir, "user_features.parquet"),
    )

    # Write feature_store.yaml
    yaml_content = (
        "project: fraud_detection\n"
        f"registry: {feast_dir}/registry.db\n"
        "provider: local\n"
        "online_store:\n"
        "  type: sqlite\n"
        f"  path: {feast_dir}/online_store.db\n"
        "offline_store:\n"
        "  type: file\n"
        "entity_key_serialization_version: 3\n"
    )
    yaml_path = os.path.join(feast_dir, "feature_store.yaml")
    with open(yaml_path, "w") as f:
        f.write(yaml_content)

    store = FeatureStore(repo_path=feast_dir)

    # Define entity and feature view
    user = Entity(name="user", join_keys=["user_id"], description="Credit card holder")

    user_source = FileSource(
        path=os.path.join(feast_dir, "user_features.parquet"),
        timestamp_field="event_timestamp",
    )

    user_stats = FeatureView(
        name="user_stats",
        entities=[user],
        ttl=timedelta(days=0),  # No expiry — workshop data has old timestamps
        schema=[
            Field(name="txn_count", dtype=Int64),
            Field(name="mean_amt", dtype=Float64),
            Field(name="std_amt", dtype=Float64),
            Field(name="max_amt", dtype=Float64),
            Field(name="home_lat", dtype=Float64),
            Field(name="home_long", dtype=Float64),
            Field(name="age", dtype=Int64),
        ],
        online=True,
        source=user_source,
    )

    # Apply and materialize
    log.info("Applying Feast definitions...")
    store.apply([user, user_stats])

    log.info("Materializing user profiles to online store...")
    store.materialize(
        start_date=datetime(2018, 1, 1, tzinfo=timezone.utc),
        end_date=datetime.now(timezone.utc),
    )

    # Rewrite feature_store.yaml with relative paths so the repo is portable across workers
    portable_yaml = (
        "project: fraud_detection\n"
        "registry: registry.db\n"
        "provider: local\n"
        "online_store:\n"
        "  type: sqlite\n"
        "  path: online_store.db\n"
        "offline_store:\n"
        "  type: file\n"
        "entity_key_serialization_version: 3\n"
    )
    with open(yaml_path, "w") as f:
        f.write(portable_yaml)

    # Re-apply with relative source path so get_historical_features works on other workers
    store = FeatureStore(repo_path=feast_dir)
    user_source = FileSource(
        path="user_features.parquet",
        timestamp_field="event_timestamp",
    )
    user_stats = FeatureView(
        name="user_stats",
        entities=[user],
        ttl=timedelta(days=0),
        schema=[
            Field(name="txn_count", dtype=Int64),
            Field(name="mean_amt", dtype=Float64),
            Field(name="std_amt", dtype=Float64),
            Field(name="max_amt", dtype=Float64),
            Field(name="home_lat", dtype=Float64),
            Field(name="home_long", dtype=Float64),
            Field(name="age", dtype=Int64),
        ],
        online=True,
        source=user_source,
    )
    store.apply([user, user_stats])

    features = ["txn_count", "mean_amt", "std_amt", "max_amt", "home_lat", "home_long", "age"]
    html = (
        '<h2>Feature Store Materialized</h2>'
        + rh.stat_grid([
            ("user_stats", "Feature View"),
            (str(len(features)), "Features"),
            ("SQLite", "Online Store"),
        ])
        + '<h3>Materialized Features</h3>'
        '<table>'
        '<tr><th>Feature</th><th>Type</th><th>Description</th></tr>'
        '<tr><td>txn_count</td><td><span class="badge badge-info">Int64</span></td><td>Total transactions</td></tr>'
        '<tr><td>mean_amt</td><td><span class="badge badge-info">Float64</span></td><td>Average transaction amount</td></tr>'
        '<tr><td>std_amt</td><td><span class="badge badge-info">Float64</span></td><td>Std dev of amounts</td></tr>'
        '<tr><td>max_amt</td><td><span class="badge badge-info">Float64</span></td><td>Max transaction amount</td></tr>'
        '<tr><td>home_lat</td><td><span class="badge badge-info">Float64</span></td><td>Home latitude (median)</td></tr>'
        '<tr><td>home_long</td><td><span class="badge badge-info">Float64</span></td><td>Home longitude (median)</td></tr>'
        '<tr><td>age</td><td><span class="badge badge-info">Int64</span></td><td>User age</td></tr>'
        '</table>'
        '<div class="note">User profiles are ready for real-time serving via the scoring app.</div>'
    )
    await flyte.report.replace.aio(rh.wrap(html))
    await flyte.report.flush.aio()

    return await flyte.io.Dir.from_local(feast_dir)

# ------------------------------------------------------------------
# Task 3: Train XGBoost model
# ------------------------------------------------------------------

@env.task(report=True)
async def train_model(
    data_dir: flyte.io.Dir,
    feast_dir: flyte.io.Dir,
    n_estimators: int = 300,
    max_depth: int = 6,
    learning_rate: float = 0.1,
    min_child_weight: int = 5,
    gamma: float = 1.0,
) -> flyte.io.File:
    """Train an XGBoost classifier using Feast for feature retrieval."""
    from feast import FeatureStore
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import classification_report, roc_auc_score, confusion_matrix
    from xgboost import XGBClassifier

    data_path = await data_dir.download()
    feast_path = await feast_dir.download()
    txn_df = pd.read_parquet(os.path.join(data_path, "transactions.parquet"))

    with open(os.path.join(data_path, "category_mapping.json")) as f:
        category_mapping = json.load(f)

    # Fetch user features from Feast (same path as serving)
    store = FeatureStore(repo_path=feast_path)
    entity_df = txn_df[["user_id", "event_timestamp"]].copy()

    log.info("Fetching user features from Feast (get_historical_features)...")
    training_data = store.get_historical_features(
        entity_df=entity_df,
        features=[
            "user_stats:txn_count",
            "user_stats:mean_amt",
            "user_stats:std_amt",
            "user_stats:max_amt",
            "user_stats:home_lat",
            "user_stats:home_long",
            "user_stats:age",
        ],
    ).to_df()

    # Merge back transaction features (Feast only returns user profile)
    training_data = training_data.merge(
        txn_df[["user_id", "event_timestamp", "amt", "amt_log", "category_encoded",
                "merch_lat", "merch_long", "hour", "day_of_week", "is_fraud"]],
        on=["user_id", "event_timestamp"],
        how="inner",
    )

    # Derived features: compare this transaction to the user's profile
    training_data["amt_zscore"] = (
        (training_data["amt"] - training_data["mean_amt"])
        / training_data["std_amt"].replace(0, 1)
    )
    training_data["amt_ratio"] = (
        training_data["amt"] / training_data["mean_amt"].replace(0, 1)
    )
    training_data["distance_from_home"] = haversine(
        training_data["home_lat"], training_data["home_long"],
        training_data["merch_lat"], training_data["merch_long"],
    )

    training_data = training_data.dropna(subset=ALL_FEATURE_COLS)
    X = training_data[ALL_FEATURE_COLS].values
    y = training_data["is_fraud"].values
    log.info(f"Training on {len(X):,} rows, {int(y.sum()):,} fraud")

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y,
    )

    n_legit = int((y_train == 0).sum())
    n_fraud = int((y_train == 1).sum())
    scale_pos_weight = n_legit / max(n_fraud, 1)

    model = XGBClassifier(
        n_estimators=n_estimators,
        max_depth=max_depth,
        learning_rate=learning_rate,
        scale_pos_weight=scale_pos_weight,
        min_child_weight=min_child_weight,
        gamma=gamma,
        random_state=42,
        eval_metric="logloss",
    )
    model.fit(X_train, y_train)

    # Evaluate
    y_pred = model.predict(X_test)
    y_proba = model.predict_proba(X_test)[:, 1]
    auc = roc_auc_score(y_test, y_proba)
    cm = confusion_matrix(y_test, y_pred)
    report = classification_report(y_test, y_pred, target_names=["Legit", "Fraud"])

    log.info(f"AUC-ROC: {auc:.4f}")
    log.info(f"\n{report}")

    # Report
    precision_fraud = cm[1][1] / max(cm[1][1] + cm[0][1], 1) * 100
    recall_fraud = cm[1][1] / max(cm[1][1] + cm[1][0], 1) * 100

    html = (
        '<h2>Model Performance</h2>'
        + rh.stat_grid([
            (f"{auc:.4f}", "AUC-ROC"),
            (f"{len(X_train):,}", "Training Samples"),
            (f"{len(X_test):,}", "Test Samples"),
            (f"{precision_fraud:.1f}%", "Fraud Precision"),
            (f"{recall_fraud:.1f}%", "Fraud Recall"),
        ])
        + rh.confusion_matrix_html(cm)
    )

    # Feature importance bar chart
    importance = model.feature_importances_
    top_idx = np.argsort(importance)[::-1]
    top_labels = [ALL_FEATURE_COLS[i] for i in top_idx]
    top_values = [float(importance[i]) for i in top_idx]
    html += '<h3>Feature Importance</h3>'
    html += f'<div class="card">{rh.horizontal_bar_chart(top_labels, top_values)}</div>'

    await flyte.report.replace.aio(rh.wrap(html))
    await flyte.report.flush.aio()

    # Save model + metadata
    model_path = os.path.join(tempfile.mkdtemp(), "model.joblib")
    joblib.dump({
        "model": model,
        "auc_roc": auc,
        "feature_cols": ALL_FEATURE_COLS,
        "category_mapping": category_mapping,
    }, model_path)

    return await flyte.io.File.from_local(model_path)

# ------------------------------------------------------------------
# Orchestrator: prepare → materialize → train
# ------------------------------------------------------------------

# {{docs-fragment pipeline}}
@env.task(report=True)
async def fraud_detection_pipeline(
    n_estimators: int = 300,
    max_depth: int = 6,
    learning_rate: float = 0.1,
    min_child_weight: int = 5,
    gamma: float = 1.0,
) -> tuple[flyte.io.File, flyte.io.Dir]:
    """
    Full fraud detection pipeline:
    1. Download and prepare data
    2. Materialize user profiles to Feast
    3. Train model using Feast for feature retrieval
    Returns model file and Feast artifacts for serving.
    """
    log.info("Starting fraud detection pipeline")
    steps = ["Prepare Data", "Materialize Features", "Train Model", "Done"]

    html = '<h2>Fraud Detection Pipeline</h2>' + rh.pipeline_step_indicator(0, steps)
    await flyte.report.replace.aio(rh.wrap(html))
    await flyte.report.flush.aio()

    data_dir = await prepare_data()

    html = '<h2>Fraud Detection Pipeline</h2>' + rh.pipeline_step_indicator(1, steps)
    await flyte.report.replace.aio(rh.wrap(html))
    await flyte.report.flush.aio()

    # Materialize features first so training can use Feast
    feast_dir = await materialize_features(data_dir)

    html = '<h2>Fraud Detection Pipeline</h2>' + rh.pipeline_step_indicator(2, steps)
    await flyte.report.replace.aio(rh.wrap(html))
    await flyte.report.flush.aio()

    # Train model using Feast for user feature retrieval
    model_file = await train_model(
        data_dir,
        feast_dir,
        n_estimators=n_estimators,
        max_depth=max_depth,
        learning_rate=learning_rate,
        min_child_weight=min_child_weight,
        gamma=gamma,
    )

    # Save copies to working directory for local app testing
    model_local = await model_file.download()
    feast_local = await feast_dir.download()
    shutil.copy2(model_local, "model.joblib")
    if os.path.exists("feast_artifacts"):
        shutil.rmtree("feast_artifacts")
    shutil.copytree(feast_local, "feast_artifacts")
    log.info("Saved local copies: model.joblib, feast_artifacts/")

    html = (
        '<h2>Fraud Detection Pipeline</h2>'
        + rh.pipeline_step_indicator(4, steps)
        + '<div class="card">'
        '<div style="font-weight:600;color:#155724;font-size:1.1em;margin-bottom:8px;">Pipeline Complete</div>'
        '<p>Model and feature store artifacts are ready for serving.</p>'
        '<table>'
        '<tr><th>Next Step</th><th>Command</th></tr>'
        '<tr><td>Run locally</td><td><code>python app.py</code></td></tr>'
        '<tr><td>Deploy scoring app</td><td><code>flyte deploy app.py serving_env</code></td></tr>'
        '<tr><td>Deploy dashboard</td><td><code>flyte deploy dashboard.py dashboard_env</code></td></tr>'
        '</table></div>'
    )
    await flyte.report.replace.aio(rh.wrap(html))
    await flyte.report.flush.aio()

    log.info("Pipeline complete")
    return model_file, feast_dir

# {{/docs-fragment pipeline}}

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(fraud_detection_pipeline)
    print(run.url)
    run.wait()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/fraud_detection_feast/fraud_detection_feast.py*

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.4.0",
#    "feast==0.63.0",
#    "xgboost==3.2.0",
#    "scikit-learn==1.8.0",
#    "kagglehub==0.3.12",
#    ...
# ]
# ///
```
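The derived features in `train_model` compare each transaction with the user's stored profile. Below is a minimal, dependency-free sketch of the same arithmetic for a single transaction. It uses `math` instead of pandas/NumPy. The helper name `derive_features` and the sample values are hypothetical. As in the tutorial code, a zero standard deviation or mean is replaced with 1 to avoid dividing by zero, and the Earth radius is 3959 miles.

```python
import math

EARTH_RADIUS_MILES = 3959  # same constant as the tutorial's haversine()


def haversine(lat1: float, lon1: float, lat2: float, lon2: float) -> float:
    """Great-circle distance in miles between two (lat, lon) points given in degrees."""
    lat1, lon1, lat2, lon2 = map(math.radians, (lat1, lon1, lat2, lon2))
    dlat = lat2 - lat1
    dlon = lon2 - lon1
    a = math.sin(dlat / 2) ** 2 + math.cos(lat1) * math.cos(lat2) * math.sin(dlon / 2) ** 2
    return 2 * EARTH_RADIUS_MILES * math.asin(math.sqrt(a))


def derive_features(txn: dict, profile: dict) -> dict:
    """Hypothetical helper: per-transaction amt_zscore, amt_ratio, and distance_from_home."""
    # Mirror the tutorial's .replace(0, 1) guard: a zero std or mean becomes 1.
    std = profile["std_amt"] or 1.0
    mean = profile["mean_amt"] or 1.0
    return {
        "amt_zscore": (txn["amt"] - profile["mean_amt"]) / std,
        "amt_ratio": txn["amt"] / mean,
        "distance_from_home": haversine(
            profile["home_lat"], profile["home_long"],
            txn["merch_lat"], txn["merch_long"],
        ),
    }


profile = {"mean_amt": 50.0, "std_amt": 25.0, "home_lat": 0.0, "home_long": 0.0}
txn = {"amt": 150.0, "merch_lat": 0.0, "merch_long": 1.0}
print(derive_features(txn, profile))
```

A purchase three times the user's average, one degree of longitude from home at the equator, yields a z-score of 4, a ratio of 3, and a distance of about 69 miles. These are exactly the kinds of deviations the model can learn to flag.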

## Orchestrate the pipeline

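The `fraud_detection_pipeline` task runs the three tasks in order. `prepare_data` downloads the dataset and engineers features. `materialize_features` applies the Feast feature definitions and materializes user profiles to the SQLite online store. `train_model` retrieves user features from Feast and trains the XGBoost classifier. Materialization runs before training so that training reads user features through Feast, which is the same path used at serving time. The task returns the model file and the Feast artifacts for serving, and it also copies both into the working directory (`model.joblib`, `feast_artifacts/`) for local app testing.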

```
# {{docs-fragment pipeline}}
@env.task(report=True)
async def fraud_detection_pipeline(
    n_estimators: int = 300,
    max_depth: int = 6,
    learning_rate: float = 0.1,
    min_child_weight: int = 5,
    gamma: float = 1.0,
) -> tuple[flyte.io.File, flyte.io.Dir]:
    """
    Full fraud detection pipeline:
    1. Download and prepare data
    2. Materialize user profiles to Feast
    3. Train model using Feast for feature retrieval
    Returns model file and Feast artifacts for serving.
    """
    log.info("Starting fraud detection pipeline")
    steps = ["Prepare Data", "Materialize Features", "Train Model", "Done"]

    html = '<h2>Fraud Detection Pipeline</h2>' + rh.pipeline_step_indicator(0, steps)
    await flyte.report.replace.aio(rh.wrap(html))
    await flyte.report.flush.aio()

    data_dir = await prepare_data()

    html = '<h2>Fraud Detection Pipeline</h2>' + rh.pipeline_step_indicator(1, steps)
    await flyte.report.replace.aio(rh.wrap(html))
    await flyte.report.flush.aio()

    # Materialize features first so training can use Feast
    feast_dir = await materialize_features(data_dir)

    html = '<h2>Fraud Detection Pipeline</h2>' + rh.pipeline_step_indicator(2, steps)
    await flyte.report.replace.aio(rh.wrap(html))
    await flyte.report.flush.aio()

    # Train model using Feast for user feature retrieval
    model_file = await train_model(
        data_dir,
        feast_dir,
        n_estimators=n_estimators,
        max_depth=max_depth,
        learning_rate=learning_rate,
        min_child_weight=min_child_weight,
        gamma=gamma,
    )

    # Save copies to working directory for local app testing
    model_local = await model_file.download()
    feast_local = await feast_dir.download()
    shutil.copy2(model_local, "model.joblib")
    if os.path.exists("feast_artifacts"):
        shutil.rmtree("feast_artifacts")
    shutil.copytree(feast_local, "feast_artifacts")
    log.info("Saved local copies: model.joblib, feast_artifacts/")

    html = (
        '<h2>Fraud Detection Pipeline</h2>'
        + rh.pipeline_step_indicator(4, steps)
        + '<div class="card">'
        '<div style="font-weight:600;color:#155724;font-size:1.1em;margin-bottom:8px;">Pipeline Complete</div>'
        '<p>Model and feature store artifacts are ready for serving.</p>'
        '<table>'
        '<tr><th>Next Step</th><th>Command</th></tr>'
        '<tr><td>Run locally</td><td><code>python app.py</code></td></tr>'
        '<tr><td>Deploy scoring app</td><td><code>flyte deploy app.py serving_env</code></td></tr>'
        '<tr><td>Deploy dashboard</td><td><code>flyte deploy dashboard.py dashboard_env</code></td></tr>'
        '</table></div>'
    )
    await flyte.report.replace.aio(rh.wrap(html))
    await flyte.report.flush.aio()

    log.info("Pipeline complete")
    return model_file, feast_dir

# {{/docs-fragment pipeline}}

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(fraud_detection_pipeline)
    print(run.url)
    run.wait()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/fraud_detection_feast/fraud_detection_feast.py*
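`train_model` retrieves user features with Feast's `get_historical_features`, which performs a point-in-time join. For each `(user_id, event_timestamp)` row in the entity DataFrame, it picks the latest feature row for that user whose timestamp is not after the event. This explains why `prepare_data` stamps each user profile with the user's earliest transaction time: a later timestamp would leave earlier transactions without features. The sketch below is a simplified, pure-Python model of that lookup using `bisect`. It is not Feast's implementation, and it ignores TTL. The tutorial sets `ttl=timedelta(days=0)`, which its code comment describes as no expiry.

```python
import bisect
from datetime import datetime


def point_in_time_join(entity_rows, feature_rows):
    """Attach to each (user_id, ts) entity row the latest feature row with timestamp <= ts.

    entity_rows: list of (user_id, ts) tuples.
    feature_rows: list of (user_id, ts, features_dict) tuples.
    Returns one features_dict per entity row (None when no row qualifies), in entity order.
    """
    # Group feature rows per user and sort each group by timestamp.
    by_user: dict = {}
    for user_id, ts, feats in feature_rows:
        by_user.setdefault(user_id, []).append((ts, feats))
    for rows in by_user.values():
        rows.sort(key=lambda r: r[0])

    joined = []
    for user_id, ts in entity_rows:
        rows = by_user.get(user_id, [])
        times = [r[0] for r in rows]
        # Index of the last feature row whose timestamp is <= the event timestamp.
        i = bisect.bisect_right(times, ts) - 1
        joined.append(rows[i][1] if i >= 0 else None)
    return joined


profile = {"mean_amt": 42.0}
txns = [(7, datetime(2019, 1, 1)), (7, datetime(2019, 6, 1))]
# Profile stamped at the user's earliest transaction: both transactions get features.
print(point_in_time_join(txns, [(7, datetime(2019, 1, 1), profile)]))
# Profile stamped later: the January transaction finds no eligible row and gets None.
print(point_in_time_join(txns, [(7, datetime(2019, 3, 1), profile)]))
```

The cost of the earliest-timestamp choice is that every transaction sees aggregates computed over the user's full history. This is a simplification the tutorial accepts so that every row can be joined.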

## Run the workflow

From the [example directory](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/fraud_detection_feast):

```
cd v2/tutorials/fraud_detection_feast
uv run --script fraud_detection_feast.py
```

The first run downloads the dataset via `kagglehub` (public dataset, no API key required). Open the run report to review the confusion matrix and feature-importance summary when training completes.
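The confusion matrix in the report follows scikit-learn's convention: rows are true labels and columns are predicted labels. `train_model` derives fraud-class precision and recall from it, using `cm[1][1]` as true positives, `cm[0][1]` as false positives, and `cm[1][0]` as false negatives. It also sets XGBoost's `scale_pos_weight` to the ratio of legitimate to fraudulent training rows. The pure-Python sketch below reproduces that arithmetic. The function names are hypothetical, and the counts are made up for illustration.

```python
def fraud_metrics(cm: list[list[int]]) -> tuple[float, float]:
    """Precision and recall (in percent) for the fraud class.

    Assumes scikit-learn's layout: cm[true_label][predicted_label].
    """
    tp = cm[1][1]  # fraud predicted as fraud
    fp = cm[0][1]  # legit predicted as fraud
    fn = cm[1][0]  # fraud predicted as legit
    # max(..., 1) mirrors the tutorial's guard against an empty denominator
    precision = tp / max(tp + fp, 1) * 100
    recall = tp / max(tp + fn, 1) * 100
    return precision, recall


def scale_pos_weight(labels: list[int]) -> float:
    """Ratio of negative (legit) to positive (fraud) labels, as passed to XGBClassifier."""
    n_legit = sum(1 for y in labels if y == 0)
    n_fraud = sum(1 for y in labels if y == 1)
    return n_legit / max(n_fraud, 1)


cm = [[990, 10],  # true legit: 990 predicted legit, 10 predicted fraud
      [5, 45]]    # true fraud: 5 predicted legit, 45 predicted fraud
print(fraud_metrics(cm))  # precision = 45 / 55, recall = 45 / 50, both as percentages
print(scale_pos_weight([0] * 99 + [1]))
```

On a dataset where fraud is rare, `scale_pos_weight` up-weights the few fraud rows so the classifier does not simply predict "legit" everywhere. Precision and recall on the fraud class are therefore more informative than overall accuracy.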

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/financial-services/financial-research-agent ===

# Financial research agent

> [!NOTE]
> Code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/financial_research_agent).

This example demonstrates how to build a financial research and earnings-cycle agent on Flyte. For each company, the agent runs grounded, source-cited research and fresh news, then synthesizes an analyst-ready equity briefing.

Financial research benefits from **low-latency, ranked, source-cited results** across both the general web and news streams. The [You.com Research API](https://you.com/docs/research/overview) produces a grounded, citation-backed synthesis, and the [You.com Search API](https://you.com/docs/search/overview) adds a fresh-news layer. [Claude](https://docs.anthropic.com/) via [LiteLLM](https://docs.litellm.ai/) turns that evidence into an analyst-ready briefing. Flyte's `cache="auto"` reuses prior results when runs converge on the same companies.
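Keeping citations intact through the LLM step requires a stable mapping from each piece of evidence to a reference the model can cite. The sketch below shows one possible way to do this. It is hypothetical and not taken from the tutorial code. It merges research and news sources, drops duplicate URLs, and numbers what remains so a prompt can ask for `[n]`-style citations. The `Evidence` type is an illustrative stand-in, and `number_sources` is a hypothetical helper.

```python
from dataclasses import dataclass


@dataclass
class Evidence:
    """Hypothetical stand-in for one cited source."""
    title: str
    url: str
    section: str  # which layer it came from, e.g. "research" or "news"


def number_sources(research: list[Evidence], news: list[Evidence]) -> tuple[str, list[Evidence]]:
    """Merge research and news sources, drop duplicate URLs, and number them for citation."""
    seen: set[str] = set()
    ordered: list[Evidence] = []
    for src in research + news:  # research first, so grounded sources get the low numbers
        if src.url in seen:
            continue
        seen.add(src.url)
        ordered.append(src)
    lines = [f"[{i}] ({s.section}) {s.title} - {s.url}" for i, s in enumerate(ordered, start=1)]
    return "\n".join(lines), ordered


text, ordered = number_sources(
    [Evidence("10-K summary", "https://a.example/10k", "research")],
    [Evidence("Earnings beat", "https://b.example/news", "news"),
     Evidence("Duplicate of the 10-K link", "https://a.example/10k", "news")],
)
print(text)
```

The returned list keeps the same order as the numbers in the prompt, so citations in the model's answer can be mapped back to titles and URLs when rendering the report.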

Flyte provides:

- **Fan-out parallelism** across companies
- **`cache="auto"`** to reuse prior You.com and LLM results across converging runs
- **`@flyte.trace`** on every external call for full prompt → citation lineage
- **Flyte reports** with thesis, risks, watch items, and source citations per company

![Financial research agent report](https://www.union.ai/docs/v2/union/_static/images/tutorials/financial_research_agent/financial-research-agent.png)
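Fan-out across companies means running one independent research branch per company concurrently and collecting the results in input order. The sketch below shows the pattern with plain `asyncio`. `research_company` is a hypothetical stand-in for the tutorial's per-company work. The semaphore limit is an added assumption that illustrates one way to keep bursts under an API rate limit. In the tutorial itself, fanned-out tasks run in separate pods rather than as local coroutines.

```python
import asyncio


async def research_company(company: str) -> str:
    """Hypothetical per-company step; a real version would call the research and news APIs."""
    await asyncio.sleep(0.01)  # simulate network latency
    return f"briefing for {company}"


async def fan_out(companies: list[str], max_concurrency: int = 4) -> list[str]:
    """Run research_company for every company concurrently, at most max_concurrency at a time."""
    semaphore = asyncio.Semaphore(max_concurrency)

    async def bounded(company: str) -> str:
        async with semaphore:
            return await research_company(company)

    # gather returns results in the order of its arguments, regardless of completion order.
    return list(await asyncio.gather(*(bounded(c) for c in companies)))


if __name__ == "__main__":
    print(asyncio.run(fan_out(["AAPL", "MSFT", "NVDA"])))
```

Because each branch depends only on its own company, the total wall-clock time is roughly that of the slowest branch rather than the sum of all branches.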

## Setting up the environment

The agent runs in a `TaskEnvironment` with secrets for the You.com and Anthropic API keys, automatic caching, and a container image built from the `uv` script dependencies.

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#     "flyte>=2.4.0",
#     "httpx>=0.27.0",
#     "litellm>=1.72.0",
# ]
# main = "financial_research"
# params = ""
# ///
"""Financial research & earnings-cycle agent.

For each company, runs grounded, source-cited research via the You.com Research
API plus a fresh-news layer via the Search API, then uses Claude to synthesize
an analyst-ready equity briefing that preserves citations. Flyte caching cuts
duplicate spend when runs converge.
"""

# {{docs-fragment env}}
import asyncio
import json
import os
from dataclasses import dataclass, field

import flyte

MODEL = "anthropic/claude-haiku-4-5"

env = flyte.TaskEnvironment(
    name="financial-research",
    secrets=[
        flyte.Secret(key="youdotcom-api-key", as_env_var="YOU_API_KEY"),
        flyte.Secret(key="internal-anthropic-api-key", as_env_var="ANTHROPIC_API_KEY"),
    ],
    image=flyte.Image.from_uv_script(__file__, name="financial-research", pre=True),
    resources=flyte.Resources(cpu="1", memory="1Gi"),
    cache="auto",
)
# {{/docs-fragment env}}

# {{docs-fragment data_types}}
@dataclass
class Source:
    title: str
    url: str
    domain: str = ""
    snippet: str = ""
    published: str = ""
    favicon: str = ""
    section: str = "research"  # "research", "news", or "web"

def _domain(url: str) -> str:
    from urllib.parse import urlparse

    try:
        return urlparse(url).netloc.replace("www.", "")
    except Exception:
        return ""

def _favicon_for(url: str) -> str:
    return f"https://ydc-index.io/favicon?domain={_domain(url)}&size=128"

@dataclass
class Briefing:
    company: str
    thesis: str
    recent_developments: list[str] = field(default_factory=list)
    risks: list[str] = field(default_factory=list)
    watch_items: list[str] = field(default_factory=list)
    sources: list[Source] = field(default_factory=list)

@dataclass
class ResearchReport:
    briefings: list[Briefing] = field(default_factory=list)
# {{/docs-fragment data_types}}

# {{docs-fragment you_apis}}
YOU_RESEARCH_URL = "https://api.you.com/v1/research"
YOU_SEARCH_URL = "https://ydc-index.io/v1/search"

async def _you_request(method: str, url: str, timeout: float, **kwargs) -> dict:
    """HTTP wrapper with exponential backoff + jitter on 429 rate limits.

    Fanned-out tasks run in separate pods, so we retry on the client side to
    smooth out bursts against the You.com API rate limit.
    """
    import asyncio
    import random

    import httpx

    headers = {"X-API-Key": os.environ["YOU_API_KEY"]}
    if method == "POST":
        headers["Content-Type"] = "application/json"

    async with httpx.AsyncClient(timeout=timeout) as client:
        for attempt in range(7):
            resp = await client.request(method, url, headers=headers, **kwargs)
            if resp.status_code == 429 and attempt < 6:
                wait = float(resp.headers.get("retry-after") or 0) or min(2**attempt, 30)
                await asyncio.sleep(wait + random.uniform(0, 2))
                continue
            resp.raise_for_status()
            return resp.json()
    resp.raise_for_status()
    return resp.json()

@flyte.trace
async def you_research(question: str, research_effort: str, freshness: str) -> dict:
    """Grounded, citation-backed research answer."""
    body = {
        "input": question,
        "research_effort": research_effort,
        "source_control": {"freshness": freshness},
    }
    return await _you_request("POST", YOU_RESEARCH_URL, 300.0, json=body)

@flyte.trace
async def you_news(query: str, count: int = 6, freshness: str = "week") -> list[dict]:
    """Fresh news headlines for a company."""
    params = {"query": query, "count": count, "freshness": freshness}
    data = await _you_request("GET", YOU_SEARCH_URL, 60.0, params=params)

    results = data.get("results", {})
    out: list[dict] = []
    for section in ("news", "web"):
        for item in results.get(section, []) or []:
            snippets = item.get("snippets") or []
            url = item.get("url", "")
            out.append(
                {
                    "title": item.get("title", ""),
                    "url": url,
                    "domain": _domain(url),
                    "snippet": snippets[0] if snippets else item.get("description", ""),
                    "published": item.get("page_age", "") or "",
                    "favicon": item.get("favicon_url")
                    or _favicon_for(url),
                    "section": section,
                }
            )
    return out
# {{/docs-fragment you_apis}}

# {{docs-fragment llm}}
@flyte.trace
async def synthesize_briefing(company: str, focus: str, research: str, news: str) -> dict:
    """Use Claude to synthesize a structured equity briefing."""
    from litellm import acompletion

    system = (
        "You are an equity research analyst. Using ONLY the grounded research "
        "and news provided, write a concise briefing. Respond ONLY with JSON: "
        '{"thesis": str, "recent_developments": [str], "risks": [str], '
        '"watch_items": [str]}. Keep each list to 3-5 short, specific bullets.'
    )
    user = (
        f"Company: {company}\nFocus: {focus}\n\n"
        f"Grounded research:\n{research}\n\nRecent news:\n{news}"
    )
    resp = await acompletion(
        model=MODEL,
        messages=[
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ],
        temperature=0.0,
        max_tokens=1536,
    )
    parsed = _parse_json(resp.choices[0].message.content)
    return parsed if isinstance(parsed, dict) else {}

def _parse_json(text: str) -> dict | list:
    text = text.strip()
    if text.startswith("```"):
        text = text.split("```", 2)[1]
        if text.lstrip().startswith("json"):
            text = text.lstrip()[4:]
    start = min((i for i in (text.find("{"), text.find("[")) if i != -1), default=0)
    end = max(text.rfind("}"), text.rfind("]")) + 1
    return json.loads(text[start:end])
# {{/docs-fragment llm}}

# {{docs-fragment research_company}}
@env.task(retries=3)
async def research_company(
    company: str,
    focus: str,
    research_effort: str,
    freshness: str,
) -> Briefing:
    """Research one company and synthesize a cited briefing."""
    question = (
        f"Provide a grounded analysis of {company} with respect to: {focus}. "
        f"Cover recent financial performance, strategic moves, competitive "
        f"positioning, and risks."
    )
    research_result, news = await asyncio.gather(
        you_research(question, research_effort, freshness),
        you_news(f"{company} earnings news", freshness=freshness),
    )

    output = research_result.get("output", {})
    research_text = output.get("content", "")
    if not isinstance(research_text, str):
        research_text = json.dumps(research_text)

    sources: list[Source] = []
    for s in output.get("sources", []) or []:
        url = str(s.get("url", ""))
        sources.append(
            Source(
                title=str(s.get("title", "") or url),
                url=url,
                domain=_domain(url),
                snippet=str((s.get("snippets") or [""])[0]),
                favicon=_favicon_for(url),
                section="research",
            )
        )
    for n in news:
        sources.append(
            Source(
                title=str(n.get("title", "")),
                url=str(n.get("url", "")),
                domain=str(n.get("domain", "")),
                snippet=str(n.get("snippet", "")),
                published=str(n.get("published", "")),
                favicon=str(n.get("favicon", "")),
                section=str(n.get("section", "web")),
            )
        )
    news_text = "\n".join(
        f"- {n['title']} ({n['published']}) {n['domain']}: {n['snippet'][:120]}"
        for n in news
    )

    parsed = await synthesize_briefing(company, focus, research_text, news_text)

    def _list(key: str) -> list[str]:
        return [str(x) for x in (parsed.get(key) or [])]

    return Briefing(
        company=company,
        thesis=str(parsed.get("thesis", "")),
        recent_developments=_list("recent_developments"),
        risks=_list("risks"),
        watch_items=_list("watch_items"),
        sources=sources,
    )
# {{/docs-fragment research_company}}

# {{docs-fragment report}}
REPORT_CSS = """
<style>
  .rpt { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto,
         Helvetica, Arial, sans-serif; color:#1f2933; max-width:1040px;
         margin:0 auto; }
  .rpt h1 { font-size:22px; margin:0 0 4px; color:#102a43; }
  .rpt .sub { color:#647488; font-size:13px; margin:0 0 18px; }
  .rpt .stats { display:flex; gap:10px; flex-wrap:wrap; margin:0 0 22px; }
  .rpt .pill { background:#f0f4f8; border-radius:999px; padding:6px 14px;
               font-size:13px; color:#334e68; }
  .rpt .pill b { color:#102a43; }
  .rpt .card { border:1px solid #e4e7eb; border-radius:12px; padding:18px 20px;
               margin:0 0 16px; box-shadow:0 1px 3px rgba(16,42,67,0.06);
               background:#fff; }
  .rpt .card h2 { font-size:18px; margin:0 0 8px; color:#102a43; }
  .rpt .thesis { font-size:14px; line-height:1.5; background:#f7f9fb;
                 border-radius:8px; padding:10px 12px; margin:0 0 14px; }
  .rpt .cols { display:flex; gap:18px; flex-wrap:wrap; }
  .rpt .col { flex:1; min-width:220px; }
  .rpt .col h3 { font-size:12px; text-transform:uppercase; letter-spacing:.04em;
                 color:#627d98; margin:0 0 6px; }
  .rpt .col.risks h3 { color:#c0392b; }
  .rpt ul { margin:0; padding-left:18px; }
  .rpt li { font-size:13px; line-height:1.5; margin:0 0 4px; }
  .rpt .sources { margin-top:14px; border-top:1px solid #f0f2f5; padding-top:10px; }
  .rpt .sources h3 { font-size:12px; text-transform:uppercase; color:#627d98;
                     margin:0 0 8px; }
  .rpt a { color:#2b6cb0; text-decoration:none; }
  .rpt a:hover { text-decoration:underline; }
  .rpt .empty { color:#829ab1; font-style:italic; padding:8px 0; }
  .rpt .cite { display:flex; gap:9px; align-items:flex-start; background:#f7f9fb;
               border:1px solid #eef1f4; border-radius:8px; padding:7px 10px;
               margin:0 0 6px; }
  .rpt .cite img.fav { width:15px; height:15px; border-radius:3px; margin-top:2px;
                       flex:0 0 auto; background:#e4e7eb; }
  .rpt .cite .cb { font-size:12px; line-height:1.4; }
  .rpt .cite .cdom { font-weight:600; color:#334e68; }
  .rpt .cite .ctag { font-size:10px; font-weight:700; text-transform:uppercase;
                     color:#fff; background:#bcccdc; border-radius:4px;
                     padding:1px 5px; margin-left:6px; }
  .rpt .cite .ctag.research { background:#5b8def; }
  .rpt .cite .ctag.news { background:#e8833a; }
  .rpt .cite .cmeta { color:#829ab1; }
  .rpt .cite .csnip { color:#52606d; font-style:italic; margin-top:2px; }
  .rpt .yoube { font-size:11px; color:#9aa5b1; margin-top:4px; }
</style>
"""

def _cite(s: Source) -> str:
    """Render a rich You.com citation (Research or Search source)."""
    if not s.url:
        return ""
    tag_cls = s.section if s.section in ("research", "news") else "web"
    meta_bits = []
    if s.published:
        meta_bits.append(s.published[:10])
    if s.title:
        meta_bits.append(s.title)
    meta = " &middot; ".join(meta_bits)
    snip = f"<div class='csnip'>&ldquo;{s.snippet}&rdquo;</div>" if s.snippet else ""
    return (
        f"<div class='cite'><img class='fav' src='{s.favicon}' alt=''/>"
        f"<div class='cb'>"
        f"<a href='{s.url}'><span class='cdom'>{s.domain or 'source'}</span></a>"
        f"<span class='ctag {tag_cls}'>{s.section}</span>"
        f"<div class='cmeta'>{meta}</div>{snip}</div></div>"
    )

def _render_report(report: ResearchReport) -> str:
    def _ul(items: list[str]) -> str:
        if not items:
            return "<p class='empty'>None reported.</p>"
        return "<ul>" + "".join(f"<li>{x}</li>" for x in items) + "</ul>"

    cards = []
    for b in report.briefings:
        src = "".join(_cite(s) for s in b.sources[:10])
        cards.append(
            f"<div class='card'><h2>{b.company}</h2>"
            f"<div class='thesis'>{b.thesis or 'No thesis generated.'}</div>"
            f"<div class='cols'>"
            f"<div class='col'><h3>Recent developments</h3>{_ul(b.recent_developments)}</div>"
            f"<div class='col risks'><h3>Risks</h3>{_ul(b.risks)}</div>"
            f"<div class='col'><h3>Watch items</h3>{_ul(b.watch_items)}</div>"
            f"</div>"
            + (f"<div class='sources'><h3>You.com sources ({len(b.sources)})</h3>{src}</div>" if src else "")
            + "</div>"
        )

    total_sources = sum(len(b.sources) for b in report.briefings)
    return f"""
    {REPORT_CSS}
    <div class="rpt">
      <h1>Financial Research Briefings</h1>
      <p class="sub">Grounded, citation-backed equity briefings — each company
      backed by You.com Research synthesis plus fresh Search news.</p>
      <div class="stats">
        <span class="pill"><b>{len(report.briefings)}</b> companies</span>
        <span class="pill"><b>{total_sources}</b> You.com sources cited</span>
      </div>
      {''.join(cards) or "<p class='empty'>No briefings generated.</p>"}
      <p class="yoube">Research answers from the You.com Research API (grounded
      synthesis with inline citations) plus fresh headlines from the You.com
      Search API (web + auto-classified news with timestamps and snippets).</p>
    </div>
    """
# {{/docs-fragment report}}

# {{docs-fragment driver}}
@env.task(report=True)
async def financial_research(
    companies: list[str] = [
        "NVIDIA",
        "Advanced Micro Devices",
        "Microsoft",
        "Alphabet",
        "Amazon",
        "Meta Platforms",
        "Broadcom",
        "Taiwan Semiconductor Manufacturing",
    ],
    focus: str = "Q4 earnings preview and competitive positioning",
    research_effort: str = "standard",
    freshness: str = "month",
) -> ResearchReport:
    """Fan out across companies and aggregate cited equity briefings."""
    with flyte.group("research-companies"):
        briefings = await asyncio.gather(
            *[
                research_company(c, focus, research_effort, freshness)
                for c in companies
            ]
        )

    report = ResearchReport(briefings=list(briefings))
    await flyte.report.replace.aio(_render_report(report), do_flush=True)
    await flyte.report.flush.aio()
    return report
# {{/docs-fragment driver}}

# {{docs-fragment main}}
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(financial_research)
    print(run.url)
    run.wait()
# {{/docs-fragment main}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/financial_research_agent/main.py*

The Python dependencies are declared at the top of the file in a PEP 723 inline script metadata block (the `# /// script` comment block that `uv` reads):

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#     "flyte>=2.4.0",
#     "httpx>=0.27.0",
#     "litellm>=1.72.0",
# ]
# ///
```

## Data types

Each `Briefing` carries a thesis, recent developments, risks, watch items, and a list of `Source` objects drawn from both the Research and Search APIs. A `ResearchReport` wraps the list of per-company briefings.

```
# {{docs-fragment data_types}}
@dataclass
class Source:
    title: str
    url: str
    domain: str = ""
    snippet: str = ""
    published: str = ""
    favicon: str = ""
    section: str = "research"  # "research", "news", or "web"

def _domain(url: str) -> str:
    from urllib.parse import urlparse

    try:
        return urlparse(url).netloc.replace("www.", "")
    except Exception:
        return ""

def _favicon_for(url: str) -> str:
    return f"https://ydc-index.io/favicon?domain={_domain(url)}&size=128"

@dataclass
class Briefing:
    company: str
    thesis: str
    recent_developments: list[str] = field(default_factory=list)
    risks: list[str] = field(default_factory=list)
    watch_items: list[str] = field(default_factory=list)
    sources: list[Source] = field(default_factory=list)

@dataclass
class ResearchReport:
    briefings: list[Briefing] = field(default_factory=list)
# {{/docs-fragment data_types}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/financial_research_agent/main.py*
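The list fields in these data types use `field(default_factory=list)`. The minimal sketch below reuses the `Source` and `Briefing` definitions from the listing to show why. Each instance gets its own list, whereas a bare `= []` default is rejected by `@dataclass`. The sketch also shows that `dataclasses.asdict` recurses into the nested `Source` objects. `sources_by_section` is a hypothetical helper, and `ExampleCo` and the URLs are placeholder values rather than tutorial output.

```python
from collections import Counter
from dataclasses import asdict, dataclass, field


@dataclass
class Source:
    title: str
    url: str
    domain: str = ""
    snippet: str = ""
    published: str = ""
    favicon: str = ""
    section: str = "research"  # "research", "news", or "web"


@dataclass
class Briefing:
    company: str
    thesis: str
    # default_factory builds a fresh list per instance; a bare `= []`
    # default makes @dataclass raise ValueError at class creation.
    recent_developments: list[str] = field(default_factory=list)
    risks: list[str] = field(default_factory=list)
    watch_items: list[str] = field(default_factory=list)
    sources: list[Source] = field(default_factory=list)


def sources_by_section(briefing: Briefing) -> dict[str, int]:
    """Hypothetical helper: count a briefing's citations per API section."""
    return dict(Counter(s.section for s in briefing.sources))


if __name__ == "__main__":
    b = Briefing(
        company="ExampleCo",
        thesis="Placeholder thesis.",
        sources=[
            Source(title="A", url="https://example.com/a"),
            Source(title="B", url="https://example.org/b", section="news"),
            Source(title="C", url="https://example.net/c", section="web"),
        ],
    )
    print(sources_by_section(b))
    # asdict() converts the nested Source dataclasses into plain dicts too.
    print(asdict(b)["sources"][1]["section"])
```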

## You.com Research and Search APIs

For each company, `research_company` calls both You.com APIs concurrently with `asyncio.gather`:

- **Research API** (`https://api.you.com/v1/research`) — grounded, citation-backed analysis with configurable `research_effort` (`lite`, `standard`, `deep`, `exhaustive`). See the [Research API reference](https://you.com/docs/api-reference/research/v1-research).
- **Search API** (`https://ydc-index.io/v1/search`) — fresh news headlines with `freshness` filtering. See the [Search API reference](https://you.com/docs/api-reference/search/v1-search).
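The `you_news` function flattens the Search API response into a single list, with `news` items first and `web` items after them. The sketch below isolates that loop as a pure function so it can be exercised without an API key.

A few things here are assumptions or simplifications:

- The payload shape (`results.news`, `results.web`, `snippets`, `description`, `page_age`) is taken from the tutorial code, not checked against the Search API reference.
- The favicon field is omitted for brevity.
- `flatten_search_results` is a hypothetical name.

```python
from urllib.parse import urlparse


def _domain(url: str) -> str:
    # Strip only a leading "www."; the tutorial's helper uses str.replace,
    # which would also remove "www." appearing elsewhere in the host name.
    return urlparse(url).netloc.removeprefix("www.")


def flatten_search_results(data: dict, sections=("news", "web")) -> list[dict]:
    """Flatten a Search-style payload into one list, section by section."""
    results = data.get("results") or {}
    out: list[dict] = []
    for section in sections:
        for item in results.get(section) or []:
            url = item.get("url", "")
            snippets = item.get("snippets") or []
            out.append(
                {
                    "title": item.get("title", ""),
                    "url": url,
                    "domain": _domain(url),
                    # Prefer the first snippet; fall back to the description.
                    "snippet": snippets[0] if snippets else item.get("description", ""),
                    "published": item.get("page_age") or "",
                    "section": section,
                }
            )
    return out


# Placeholder payload in the assumed shape; not real API output.
SAMPLE = {
    "results": {
        "news": [
            {"title": "Headline", "url": "https://www.example.com/n1",
             "snippets": ["First snippet."], "page_age": "2025-01-15T08:00:00"},
        ],
        "web": [
            {"title": "Page", "url": "https://example.org/w1",
             "description": "Fallback description."},
            {"title": "Other", "url": "https://example.net/w2", "snippets": []},
        ],
    }
}

if __name__ == "__main__":
    for row in flatten_search_results(SAMPLE):
        print(row["section"], row["domain"], row["snippet"])
```

Keeping the `section` label on each row is what lets the report tag every citation as research, news, or web.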

```
# {{docs-fragment you_apis}}
YOU_RESEARCH_URL = "https://api.you.com/v1/research"
YOU_SEARCH_URL = "https://ydc-index.io/v1/search"

async def _you_request(method: str, url: str, timeout: float, **kwargs) -> dict:
    """HTTP wrapper with exponential backoff + jitter on 429 rate limits.

    Fanned-out tasks run in separate pods, so we retry on the client side to
    smooth out bursts against the You.com API rate limit.
    """
    import asyncio
    import random

    import httpx

    headers = {"X-API-Key": os.environ["YOU_API_KEY"]}
    if method == "POST":
        headers["Content-Type"] = "application/json"

    async with httpx.AsyncClient(timeout=timeout) as client:
        for attempt in range(7):
            resp = await client.request(method, url, headers=headers, **kwargs)
            if resp.status_code == 429 and attempt < 6:
                wait = float(resp.headers.get("retry-after") or 0) or min(2**attempt, 30)
                await asyncio.sleep(wait + random.uniform(0, 2))
                continue
            resp.raise_for_status()
            return resp.json()
    resp.raise_for_status()
    return resp.json()

@flyte.trace
async def you_research(question: str, research_effort: str, freshness: str) -> dict:
    """Grounded, citation-backed research answer."""
    body = {
        "input": question,
        "research_effort": research_effort,
        "source_control": {"freshness": freshness},
    }
    return await _you_request("POST", YOU_RESEARCH_URL, 300.0, json=body)

@flyte.trace
async def you_news(query: str, count: int = 6, freshness: str = "week") -> list[dict]:
    """Fresh news headlines for a company."""
    params = {"query": query, "count": count, "freshness": freshness}
    data = await _you_request("GET", YOU_SEARCH_URL, 60.0, params=params)

    results = data.get("results", {})
    out: list[dict] = []
    for section in ("news", "web"):
        for item in results.get(section, []) or []:
            snippets = item.get("snippets") or []
            url = item.get("url", "")
            out.append(
                {
                    "title": item.get("title", ""),
                    "url": url,
                    "domain": _domain(url),
                    "snippet": snippets[0] if snippets else item.get("description", ""),
                    "published": item.get("page_age", "") or "",
                    "favicon": item.get("favicon_url")
                    or _favicon_for(url),
                    "section": section,
                }
            )
    return out
# {{/docs-fragment you_apis}}

# {{docs-fragment llm}}
@flyte.trace
async def synthesize_briefing(company: str, focus: str, research: str, news: str) -> dict:
    """Use Claude to synthesize a structured equity briefing."""
    from litellm import acompletion

    system = (
        "You are an equity research analyst. Using ONLY the grounded research "
        "and news provided, write a concise briefing. Respond ONLY with JSON: "
        '{"thesis": str, "recent_developments": [str], "risks": [str], '
        '"watch_items": [str]}. Keep each list to 3-5 short, specific bullets.'
    )
    user = (
        f"Company: {company}\nFocus: {focus}\n\n"
        f"Grounded research:\n{research}\n\nRecent news:\n{news}"
    )
    resp = await acompletion(
        model=MODEL,
        messages=[
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ],
        temperature=0.0,
        max_tokens=1536,
    )
    parsed = _parse_json(resp.choices[0].message.content)
    return parsed if isinstance(parsed, dict) else {}

def _parse_json(text: str) -> dict | list:
    text = text.strip()
    if text.startswith("```"):
        text = text.split("```", 2)[1]
        if text.lstrip().startswith("json"):
            text = text.lstrip()[4:]
    start = min((i for i in (text.find("{"), text.find("[")) if i != -1), default=0)
    end = max(text.rfind("}"), text.rfind("]")) + 1
    return json.loads(text[start:end])
# {{/docs-fragment llm}}

# {{docs-fragment research_company}}
@env.task(retries=3)
async def research_company(
    company: str,
    focus: str,
    research_effort: str,
    freshness: str,
) -> Briefing:
    """Research one company and synthesize a cited briefing."""
    question = (
        f"Provide a grounded analysis of {company} with respect to: {focus}. "
        f"Cover recent financial performance, strategic moves, competitive "
        f"positioning, and risks."
    )
    research_result, news = await asyncio.gather(
        you_research(question, research_effort, freshness),
        you_news(f"{company} earnings news", freshness=freshness),
    )

    output = research_result.get("output", {})
    research_text = output.get("content", "")
    if not isinstance(research_text, str):
        research_text = json.dumps(research_text)

    sources: list[Source] = []
    for s in output.get("sources", []) or []:
        url = str(s.get("url", ""))
        sources.append(
            Source(
                title=str(s.get("title", "") or url),
                url=url,
                domain=_domain(url),
                snippet=str((s.get("snippets") or [""])[0]),
                favicon=_favicon_for(url),
                section="research",
            )
        )
    for n in news:
        sources.append(
            Source(
                title=str(n.get("title", "")),
                url=str(n.get("url", "")),
                domain=str(n.get("domain", "")),
                snippet=str(n.get("snippet", "")),
                published=str(n.get("published", "")),
                favicon=str(n.get("favicon", "")),
                section=str(n.get("section", "web")),
            )
        )
    news_text = "\n".join(
        f"- {n['title']} ({n['published']}) {n['domain']}: {n['snippet'][:120]}"
        for n in news
    )

    parsed = await synthesize_briefing(company, focus, research_text, news_text)

    def _list(key: str) -> list[str]:
        return [str(x) for x in (parsed.get(key) or [])]

    return Briefing(
        company=company,
        thesis=str(parsed.get("thesis", "")),
        recent_developments=_list("recent_developments"),
        risks=_list("risks"),
        watch_items=_list("watch_items"),
        sources=sources,
    )
# {{/docs-fragment research_company}}

# {{docs-fragment report}}
REPORT_CSS = """
<style>
  .rpt { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto,
         Helvetica, Arial, sans-serif; color:#1f2933; max-width:1040px;
         margin:0 auto; }
  .rpt h1 { font-size:22px; margin:0 0 4px; color:#102a43; }
  .rpt .sub { color:#647488; font-size:13px; margin:0 0 18px; }
  .rpt .stats { display:flex; gap:10px; flex-wrap:wrap; margin:0 0 22px; }
  .rpt .pill { background:#f0f4f8; border-radius:999px; padding:6px 14px;
               font-size:13px; color:#334e68; }
  .rpt .pill b { color:#102a43; }
  .rpt .card { border:1px solid #e4e7eb; border-radius:12px; padding:18px 20px;
               margin:0 0 16px; box-shadow:0 1px 3px rgba(16,42,67,0.06);
               background:#fff; }
  .rpt .card h2 { font-size:18px; margin:0 0 8px; color:#102a43; }
  .rpt .thesis { font-size:14px; line-height:1.5; background:#f7f9fb;
                 border-radius:8px; padding:10px 12px; margin:0 0 14px; }
  .rpt .cols { display:flex; gap:18px; flex-wrap:wrap; }
  .rpt .col { flex:1; min-width:220px; }
  .rpt .col h3 { font-size:12px; text-transform:uppercase; letter-spacing:.04em;
                 color:#627d98; margin:0 0 6px; }
  .rpt .col.risks h3 { color:#c0392b; }
  .rpt ul { margin:0; padding-left:18px; }
  .rpt li { font-size:13px; line-height:1.5; margin:0 0 4px; }
  .rpt .sources { margin-top:14px; border-top:1px solid #f0f2f5; padding-top:10px; }
  .rpt .sources h3 { font-size:12px; text-transform:uppercase; color:#627d98;
                     margin:0 0 8px; }
  .rpt a { color:#2b6cb0; text-decoration:none; }
  .rpt a:hover { text-decoration:underline; }
  .rpt .empty { color:#829ab1; font-style:italic; padding:8px 0; }
  .rpt .cite { display:flex; gap:9px; align-items:flex-start; background:#f7f9fb;
               border:1px solid #eef1f4; border-radius:8px; padding:7px 10px;
               margin:0 0 6px; }
  .rpt .cite img.fav { width:15px; height:15px; border-radius:3px; margin-top:2px;
                       flex:0 0 auto; background:#e4e7eb; }
  .rpt .cite .cb { font-size:12px; line-height:1.4; }
  .rpt .cite .cdom { font-weight:600; color:#334e68; }
  .rpt .cite .ctag { font-size:10px; font-weight:700; text-transform:uppercase;
                     color:#fff; background:#bcccdc; border-radius:4px;
                     padding:1px 5px; margin-left:6px; }
  .rpt .cite .ctag.research { background:#5b8def; }
  .rpt .cite .ctag.news { background:#e8833a; }
  .rpt .cite .cmeta { color:#829ab1; }
  .rpt .cite .csnip { color:#52606d; font-style:italic; margin-top:2px; }
  .rpt .yoube { font-size:11px; color:#9aa5b1; margin-top:4px; }
</style>
"""

def _cite(s: Source) -> str:
    """Render a rich You.com citation (Research or Search source)."""
    from html import escape  # titles, snippets, and URLs come from the web: escape them

    if not s.url:
        return ""
    tag_cls = s.section if s.section in ("research", "news") else "web"
    meta_bits = []
    if s.published:
        meta_bits.append(escape(s.published[:10]))
    if s.title:
        meta_bits.append(escape(s.title))
    meta = " &middot; ".join(meta_bits)
    snip = f"<div class='csnip'>&ldquo;{escape(s.snippet)}&rdquo;</div>" if s.snippet else ""
    return (
        f"<div class='cite'><img class='fav' src='{escape(s.favicon)}' alt=''/>"
        f"<div class='cb'>"
        f"<a href='{escape(s.url)}'><span class='cdom'>{escape(s.domain or 'source')}</span></a>"
        f"<span class='ctag {tag_cls}'>{escape(s.section)}</span>"
        f"<div class='cmeta'>{meta}</div>{snip}</div></div>"
    )

def _render_report(report: ResearchReport) -> str:
    from html import escape  # model output may contain <, > or & that would break markup

    def _ul(items: list[str]) -> str:
        if not items:
            return "<p class='empty'>None reported.</p>"
        return "<ul>" + "".join(f"<li>{escape(x)}</li>" for x in items) + "</ul>"

    cards = []
    for b in report.briefings:
        src = "".join(_cite(s) for s in b.sources[:10])
        cards.append(
            f"<div class='card'><h2>{escape(b.company)}</h2>"
            f"<div class='thesis'>{escape(b.thesis or 'No thesis generated.')}</div>"
            f"<div class='cols'>"
            f"<div class='col'><h3>Recent developments</h3>{_ul(b.recent_developments)}</div>"
            f"<div class='col risks'><h3>Risks</h3>{_ul(b.risks)}</div>"
            f"<div class='col'><h3>Watch items</h3>{_ul(b.watch_items)}</div>"
            f"</div>"
            + (f"<div class='sources'><h3>You.com sources ({len(b.sources)})</h3>{src}</div>" if src else "")
            + "</div>"
        )

    total_sources = sum(len(b.sources) for b in report.briefings)
    return f"""
    {REPORT_CSS}
    <div class="rpt">
      <h1>Financial Research Briefings</h1>
      <p class="sub">Grounded, citation-backed equity briefings — each company
      backed by You.com Research synthesis plus fresh Search news.</p>
      <div class="stats">
        <span class="pill"><b>{len(report.briefings)}</b> companies</span>
        <span class="pill"><b>{total_sources}</b> You.com sources cited</span>
      </div>
      {''.join(cards) or "<p class='empty'>No briefings generated.</p>"}
      <p class="yoube">Research answers from the You.com Research API (grounded
      synthesis with inline citations) plus fresh headlines from the You.com
      Search API (web + auto-classified news with timestamps and snippets).</p>
    </div>
    """
# {{/docs-fragment report}}

# {{docs-fragment driver}}
@env.task(report=True)
async def financial_research(
    companies: list[str] = [
        "NVIDIA",
        "Advanced Micro Devices",
        "Microsoft",
        "Alphabet",
        "Amazon",
        "Meta Platforms",
        "Broadcom",
        "Taiwan Semiconductor Manufacturing",
    ],
    focus: str = "Q4 earnings preview and competitive positioning",
    research_effort: str = "standard",
    freshness: str = "month",
) -> ResearchReport:
    """Fan out across companies and aggregate cited equity briefings."""
    with flyte.group("research-companies"):
        briefings = await asyncio.gather(
            *[
                research_company(c, focus, research_effort, freshness)
                for c in companies
            ]
        )

    report = ResearchReport(briefings=list(briefings))
    await flyte.report.replace.aio(_render_report(report), do_flush=True)
    await flyte.report.flush.aio()
    return report
# {{/docs-fragment driver}}

# {{docs-fragment main}}
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(financial_research)
    print(run.url)
    run.wait()
# {{/docs-fragment main}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/financial_research_agent/main.py*

## Synthesize briefings with Claude

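The `synthesize_briefing` function sends the grounded research answer and the news headlines to Claude through LiteLLM's `acompletion`, with `temperature=0.0`. The system prompt restricts the model to the supplied evidence and asks for JSON only: a `thesis` string plus three short bullet lists (`recent_developments`, `risks`, `watch_items`). `_parse_json` then extracts that object from the reply, tolerating an optional Markdown code fence.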
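Nothing in the pipeline enforces the JSON shape the prompt requests, so a reply can still omit a key or return a single string where a list is expected. In that case the `_list` helper in `research_company` would iterate the string and produce one bullet per character. The sketch below shows one way to coerce the parsed reply into the requested shape before building a `Briefing`. The name `normalize_briefing` and the five-bullet cap (taken from the prompt's "3-5 bullets" instruction, which the tutorial does not enforce) are assumptions for illustration.

```python
from typing import Any

LIST_KEYS = ("recent_developments", "risks", "watch_items")


def normalize_briefing(parsed: Any, max_bullets: int = 5) -> dict[str, Any]:
    """Coerce model JSON into the shape the system prompt asks for (hypothetical helper).

    Expected keys: "thesis" (str) plus three lists of short bullet strings.
    """
    if not isinstance(parsed, dict):
        parsed = {}
    out: dict[str, Any] = {"thesis": str(parsed.get("thesis") or "").strip()}
    for key in LIST_KEYS:
        value = parsed.get(key) or []
        if not isinstance(value, list):
            # A bare string (or any scalar) becomes a single bullet instead of
            # being iterated character by character.
            value = [value]
        bullets = [str(x).strip() for x in value if str(x).strip()]
        out[key] = bullets[:max_bullets]  # assumption: cap at the prompt's 5 bullets
    return out


print(normalize_briefing({"thesis": "Demand outpaces supply", "risks": "Export controls"}))
# {'thesis': 'Demand outpaces supply', 'recent_developments': [], 'risks': ['Export controls'], 'watch_items': []}
```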

```
# {{docs-fragment llm}}
@flyte.trace
async def synthesize_briefing(company: str, focus: str, research: str, news: str) -> dict:
    """Use Claude to synthesize a structured equity briefing."""
    from litellm import acompletion

    system = (
        "You are an equity research analyst. Using ONLY the grounded research "
        "and news provided, write a concise briefing. Respond ONLY with JSON: "
        '{"thesis": str, "recent_developments": [str], "risks": [str], '
        '"watch_items": [str]}. Keep each list to 3-5 short, specific bullets.'
    )
    user = (
        f"Company: {company}\nFocus: {focus}\n\n"
        f"Grounded research:\n{research}\n\nRecent news:\n{news}"
    )
    resp = await acompletion(
        model=MODEL,
        messages=[
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ],
        temperature=0.0,
        max_tokens=1536,
    )
    parsed = _parse_json(resp.choices[0].message.content)
    return parsed if isinstance(parsed, dict) else {}

def _parse_json(text: str) -> dict | list:
    text = text.strip()
    if text.startswith("```"):
        text = text.split("```", 2)[1]
        if text.lstrip().startswith("json"):
            text = text.lstrip()[4:]
    start = min((i for i in (text.find("{"), text.find("[")) if i != -1), default=0)
    end = max(text.rfind("}"), text.rfind("]")) + 1
    return json.loads(text[start:end])
# {{/docs-fragment llm}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/financial_research_agent/main.py*
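`_parse_json` slices from the first `{` or `[` to the last `}` or `]`. If the model skips the code fence and adds a trailing remark that contains a brace (for example, "All figures in {USD}."), that slice includes the remark and `json.loads` fails. The sketch below swaps in a different technique, `json.JSONDecoder.raw_decode`, which parses one complete JSON value starting at a given index and ignores whatever follows. The name `extract_json` is hypothetical; this is an alternative, not the tutorial's implementation.

```python
import json
from typing import Any

_DECODER = json.JSONDecoder()


def extract_json(text: str) -> Any:
    """Return the first JSON object or array in an LLM reply (hypothetical helper)."""
    for i, ch in enumerate(text):
        if ch in "{[":
            try:
                # raw_decode parses a single value starting at index i and
                # returns (value, end_index), ignoring any trailing text.
                value, _end = _DECODER.raw_decode(text, i)
                return value
            except json.JSONDecodeError:
                continue  # e.g. a stray "[" in prose; try the next candidate
    raise ValueError("no JSON object or array found in model output")


print(extract_json('{"thesis": "ok"}\nAll figures in {USD}.'))  # {'thesis': 'ok'}
```

Fenced replies also work, because the scan starts at the first brace inside the fence and the closing backticks are simply ignored.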

## Research one company

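The `research_company` task calls the You.com Research and Search APIs concurrently with `asyncio.gather`, merges their citations into a single list of `Source` records (research sources first, then news and web results), and passes the research text plus a condensed news digest to `synthesize_briefing`. The task is declared with `retries=3`, so a transient failure, such as a rate limit that outlasts the client-side backoff, reruns it.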
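The concurrency and merge pattern in `research_company` can be exercised without network access. In the sketch below, the hypothetical stubs `fake_research` and `fake_news` stand in for the traced You.com calls and sleep to simulate latency; `gather_sources` runs them with `asyncio.gather` and merges citations the way the task does, tagging research sources `"research"` and keeping each search result's `"news"` or `"web"` section. The `Source` class here is a trimmed-down stand-in, not the tutorial's dataclass.

```python
import asyncio
import time
from dataclasses import dataclass


@dataclass
class Source:  # trimmed-down stand-in for the tutorial's Source dataclass
    title: str
    url: str
    section: str = "research"


async def fake_research(question: str) -> dict:
    """Hypothetical stub for you_research: returns a Research-API-shaped dict."""
    await asyncio.sleep(0.2)
    return {"output": {"content": f"answer to: {question}",
                       "sources": [{"title": "10-K", "url": "https://example.com/10k"}]}}


async def fake_news(query: str) -> list[dict]:
    """Hypothetical stub for you_news: returns flattened search results."""
    await asyncio.sleep(0.2)
    return [{"title": f"{query} headline", "url": "https://example.com/n1", "section": "news"},
            {"title": "Blog post", "url": "https://example.com/w1", "section": "web"}]


async def gather_sources(company: str) -> tuple[str, list[Source]]:
    # Both coroutines run concurrently on the same event loop.
    research, news = await asyncio.gather(
        fake_research(f"Analyze {company}"),
        fake_news(f"{company} earnings news"),
    )
    output = research.get("output", {})
    sources = [
        Source(title=str(s.get("title") or s.get("url", "")), url=str(s.get("url", "")))
        for s in output.get("sources", []) or []
    ]
    sources += [
        Source(title=str(n.get("title", "")), url=str(n.get("url", "")),
               section=str(n.get("section", "web")))
        for n in news
    ]
    return str(output.get("content", "")), sources


start = time.perf_counter()
text, sources = asyncio.run(gather_sources("NVIDIA"))
elapsed = time.perf_counter() - start
# Because the two calls overlap, elapsed should be close to 0.2 s rather than 0.4 s.
print(text, [s.section for s in sources], f"{elapsed:.2f}s")
```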

```
# {{docs-fragment research_company}}
@env.task(retries=3)
async def research_company(
    company: str,
    focus: str,
    research_effort: str,
    freshness: str,
) -> Briefing:
    """Research one company and synthesize a cited briefing."""
    question = (
        f"Provide a grounded analysis of {company} with respect to: {focus}. "
        f"Cover recent financial performance, strategic moves, competitive "
        f"positioning, and risks."
    )
    research_result, news = await asyncio.gather(
        you_research(question, research_effort, freshness),
        you_news(f"{company} earnings news", freshness=freshness),
    )

    output = research_result.get("output", {})
    research_text = output.get("content", "")
    if not isinstance(research_text, str):
        research_text = json.dumps(research_text)

    sources: list[Source] = []
    for s in output.get("sources", []) or []:
        url = str(s.get("url", ""))
        sources.append(
            Source(
                title=str(s.get("title", "") or url),
                url=url,
                domain=_domain(url),
                snippet=str((s.get("snippets") or [""])[0]),
                favicon=_favicon_for(url),
                section="research",
            )
        )
    for n in news:
        sources.append(
            Source(
                title=str(n.get("title", "")),
                url=str(n.get("url", "")),
                domain=str(n.get("domain", "")),
                snippet=str(n.get("snippet", "")),
                published=str(n.get("published", "")),
                favicon=str(n.get("favicon", "")),
                section=str(n.get("section", "web")),
            )
        )
    news_text = "\n".join(
        f"- {n['title']} ({n['published']}) {n['domain']}: {n['snippet'][:120]}"
        for n in news
    )

    parsed = await synthesize_briefing(company, focus, research_text, news_text)

    def _list(key: str) -> list[str]:
        return [str(x) for x in (parsed.get(key) or [])]

    return Briefing(
        company=company,
        thesis=str(parsed.get("thesis", "")),
        recent_developments=_list("recent_developments"),
        risks=_list("risks"),
        watch_items=_list("watch_items"),
        sources=sources,
    )
# {{/docs-fragment research_company}}

# {{docs-fragment report}}
REPORT_CSS = """
<style>
  .rpt { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto,
         Helvetica, Arial, sans-serif; color:#1f2933; max-width:1040px;
         margin:0 auto; }
  .rpt h1 { font-size:22px; margin:0 0 4px; color:#102a43; }
  .rpt .sub { color:#647488; font-size:13px; margin:0 0 18px; }
  .rpt .stats { display:flex; gap:10px; flex-wrap:wrap; margin:0 0 22px; }
  .rpt .pill { background:#f0f4f8; border-radius:999px; padding:6px 14px;
               font-size:13px; color:#334e68; }
  .rpt .pill b { color:#102a43; }
  .rpt .card { border:1px solid #e4e7eb; border-radius:12px; padding:18px 20px;
               margin:0 0 16px; box-shadow:0 1px 3px rgba(16,42,67,0.06);
               background:#fff; }
  .rpt .card h2 { font-size:18px; margin:0 0 8px; color:#102a43; }
  .rpt .thesis { font-size:14px; line-height:1.5; background:#f7f9fb;
                 border-radius:8px; padding:10px 12px; margin:0 0 14px; }
  .rpt .cols { display:flex; gap:18px; flex-wrap:wrap; }
  .rpt .col { flex:1; min-width:220px; }
  .rpt .col h3 { font-size:12px; text-transform:uppercase; letter-spacing:.04em;
                 color:#627d98; margin:0 0 6px; }
  .rpt .col.risks h3 { color:#c0392b; }
  .rpt ul { margin:0; padding-left:18px; }
  .rpt li { font-size:13px; line-height:1.5; margin:0 0 4px; }
  .rpt .sources { margin-top:14px; border-top:1px solid #f0f2f5; padding-top:10px; }
  .rpt .sources h3 { font-size:12px; text-transform:uppercase; color:#627d98;
                     margin:0 0 8px; }
  .rpt a { color:#2b6cb0; text-decoration:none; }
  .rpt a:hover { text-decoration:underline; }
  .rpt .empty { color:#829ab1; font-style:italic; padding:8px 0; }
  .rpt .cite { display:flex; gap:9px; align-items:flex-start; background:#f7f9fb;
               border:1px solid #eef1f4; border-radius:8px; padding:7px 10px;
               margin:0 0 6px; }
  .rpt .cite img.fav { width:15px; height:15px; border-radius:3px; margin-top:2px;
                       flex:0 0 auto; background:#e4e7eb; }
  .rpt .cite .cb { font-size:12px; line-height:1.4; }
  .rpt .cite .cdom { font-weight:600; color:#334e68; }
  .rpt .cite .ctag { font-size:10px; font-weight:700; text-transform:uppercase;
                     color:#fff; background:#bcccdc; border-radius:4px;
                     padding:1px 5px; margin-left:6px; }
  .rpt .cite .ctag.research { background:#5b8def; }
  .rpt .cite .ctag.news { background:#e8833a; }
  .rpt .cite .cmeta { color:#829ab1; }
  .rpt .cite .csnip { color:#52606d; font-style:italic; margin-top:2px; }
  .rpt .yoube { font-size:11px; color:#9aa5b1; margin-top:4px; }
</style>
"""

def _cite(s: Source) -> str:
    """Render a rich You.com citation (Research or Search source)."""
    from html import escape  # titles, snippets, and URLs are untrusted web content

    if not s.url:
        return ""
    tag_cls = s.section if s.section in ("research", "news") else "web"
    meta_bits = []
    if s.published:
        meta_bits.append(escape(s.published[:10]))
    if s.title:
        meta_bits.append(escape(s.title))
    meta = " &middot; ".join(meta_bits)
    snip = f"<div class='csnip'>&ldquo;{escape(s.snippet)}&rdquo;</div>" if s.snippet else ""
    return (
        f"<div class='cite'><img class='fav' src='{escape(s.favicon)}' alt=''/>"
        f"<div class='cb'>"
        f"<a href='{escape(s.url)}'><span class='cdom'>{escape(s.domain or 'source')}</span></a>"
        f"<span class='ctag {tag_cls}'>{escape(s.section)}</span>"
        f"<div class='cmeta'>{meta}</div>{snip}</div></div>"
    )

def _render_report(report: ResearchReport) -> str:
    from html import escape  # LLM output can contain <, >, &, or quotes

    def _ul(items: list[str]) -> str:
        if not items:
            return "<p class='empty'>None reported.</p>"
        return "<ul>" + "".join(f"<li>{escape(x)}</li>" for x in items) + "</ul>"

    cards = []
    for b in report.briefings:
        src = "".join(_cite(s) for s in b.sources[:10])
        cards.append(
            f"<div class='card'><h2>{escape(b.company)}</h2>"
            f"<div class='thesis'>{escape(b.thesis or 'No thesis generated.')}</div>"
            f"<div class='cols'>"
            f"<div class='col'><h3>Recent developments</h3>{_ul(b.recent_developments)}</div>"
            f"<div class='col risks'><h3>Risks</h3>{_ul(b.risks)}</div>"
            f"<div class='col'><h3>Watch items</h3>{_ul(b.watch_items)}</div>"
            f"</div>"
            + (f"<div class='sources'><h3>You.com sources ({len(b.sources)})</h3>{src}</div>" if src else "")
            + "</div>"
        )

    total_sources = sum(len(b.sources) for b in report.briefings)
    return f"""
    {REPORT_CSS}
    <div class="rpt">
      <h1>Financial Research Briefings</h1>
      <p class="sub">Grounded, citation-backed equity briefings — each company
      backed by You.com Research synthesis plus fresh Search news.</p>
      <div class="stats">
        <span class="pill"><b>{len(report.briefings)}</b> companies</span>
        <span class="pill"><b>{total_sources}</b> You.com sources cited</span>
      </div>
      {''.join(cards) or "<p class='empty'>No briefings generated.</p>"}
      <p class="yoube">Research answers from the You.com Research API (grounded
      synthesis with inline citations) plus fresh headlines from the You.com
      Search API (web + auto-classified news with timestamps and snippets).</p>
    </div>
    """
# {{/docs-fragment report}}

# {{docs-fragment driver}}
@env.task(report=True)
async def financial_research(
    companies: list[str] = [
        "NVIDIA",
        "Advanced Micro Devices",
        "Microsoft",
        "Alphabet",
        "Amazon",
        "Meta Platforms",
        "Broadcom",
        "Taiwan Semiconductor Manufacturing",
    ],
    focus: str = "Q4 earnings preview and competitive positioning",
    research_effort: str = "standard",
    freshness: str = "month",
) -> ResearchReport:
    """Fan out across companies and aggregate cited equity briefings."""
    with flyte.group("research-companies"):
        briefings = await asyncio.gather(
            *[
                research_company(c, focus, research_effort, freshness)
                for c in companies
            ]
        )

    report = ResearchReport(briefings=list(briefings))
    await flyte.report.replace.aio(_render_report(report), do_flush=True)
    await flyte.report.flush.aio()
    return report
# {{/docs-fragment driver}}

# {{docs-fragment main}}
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(financial_research)
    print(run.url)
    run.wait()
# {{/docs-fragment main}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/financial_research_agent/main.py*

## Orchestration

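The `financial_research` driver task launches one `research_company` task per company. These run concurrently through `asyncio.gather`, inside a `flyte.group` named `research-companies`. The driver then renders a Flyte report with per-company briefings and citations.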

```
@env.task(report=True)
async def financial_research(
    companies: list[str] = [
        "NVIDIA",
        "Advanced Micro Devices",
        "Microsoft",
        "Alphabet",
        "Amazon",
        "Meta Platforms",
        "Broadcom",
        "Taiwan Semiconductor Manufacturing",
    ],
    focus: str = "Q4 earnings preview and competitive positioning",
    research_effort: str = "standard",
    freshness: str = "month",
) -> ResearchReport:
    """Fan out across companies and aggregate cited equity briefings."""
    with flyte.group("research-companies"):
        briefings = await asyncio.gather(
            *[
                research_company(c, focus, research_effort, freshness)
                for c in companies
            ]
        )

    report = ResearchReport(briefings=list(briefings))
    await flyte.report.replace.aio(_render_report(report), do_flush=True)
    await flyte.report.flush.aio()
    return report
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/financial_research_agent/main.py*
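
The driver can pair `briefings` with `companies` without extra bookkeeping. This works because `asyncio.gather` returns results in the order the awaitables were passed, not the order in which they finish. The sketch below demonstrates this with plain `asyncio`. `fake_research` is a hypothetical stand-in for `research_company` that sleeps for a varying time, so the earliest company in the list finishes last.

```python
import asyncio


async def fake_research(company: str, delay: float) -> str:
    """Hypothetical stand-in for research_company: finishes after `delay` seconds."""
    await asyncio.sleep(delay)
    return f"briefing for {company}"


async def fan_out(companies: list[str]) -> list[str]:
    # Earlier companies get longer delays, so they complete last...
    delays = [0.05 * (len(companies) - i) for i in range(len(companies))]
    return await asyncio.gather(
        *[fake_research(c, d) for c, d in zip(companies, delays)]
    )


results = asyncio.run(fan_out(["NVIDIA", "Microsoft", "Alphabet"]))
# ...but gather still returns results in input order.
print(results)
# ['briefing for NVIDIA', 'briefing for Microsoft', 'briefing for Alphabet']
```

With default arguments, the first exception raised by any awaitable propagates out of `gather`. Assuming a failed task surfaces as an exception when awaited, a company whose task exhausts its `retries=3` would therefore fail the driver. Passing `return_exceptions=True` instead places exceptions in the result list, which you could use to render a partial report.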

## Run the agent

### Create secrets

Get a You.com API key from the [You.com platform](https://you.com/platform) (see the [quickstart guide](https://you.com/docs/quickstart)). Get an Anthropic API key from the [Anthropic console](https://console.anthropic.com/).

Register both keys as Flyte secrets. The secret key names must match those declared in the `TaskEnvironment`:

```
flyte create secret youdotcom-api-key <YOUR_YOU_API_KEY>
flyte create secret internal-anthropic-api-key <YOUR_ANTHROPIC_API_KEY>
```

See [Secrets](https://www.union.ai/docs/v2/union/user-guide/task-configuration/secrets/page.md) for scoping and file-based secrets.

### Run locally or remotely

From the [example directory](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/financial_research_agent):

```
cd v2/tutorials/financial_research_agent
uv run --script main.py
```

To test locally without Flyte secrets:

```
export YOU_API_KEY=<YOUR_YOU_API_KEY>
export ANTHROPIC_API_KEY=<YOUR_ANTHROPIC_API_KEY>

uv run --script main.py
```
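
With exported keys, a missing `YOU_API_KEY` only surfaces as a `KeyError` from `os.environ["YOU_API_KEY"]` inside `_you_request`, once the first request is made. Similarly, a missing `ANTHROPIC_API_KEY` only fails when LiteLLM makes the Claude call. A pre-flight check like the one below reports both up front. It is a hypothetical helper that is not part of the example.

```python
import os
from collections.abc import Mapping

# Variables the tasks read, matching the TaskEnvironment's `as_env_var` names.
REQUIRED_ENV_VARS = ("YOU_API_KEY", "ANTHROPIC_API_KEY")


def missing_env_vars(env: Mapping[str, str] = os.environ) -> list[str]:
    """Return required variables that are unset or empty (hypothetical helper)."""
    return [name for name in REQUIRED_ENV_VARS if not env.get(name)]


def require_env_vars(env: Mapping[str, str] = os.environ) -> None:
    """Raise once, listing every missing variable."""
    missing = missing_env_vars(env)
    if missing:
        raise RuntimeError(f"Missing environment variables: {', '.join(missing)}")
```

You could call `require_env_vars()` at the top of the `if __name__ == "__main__":` block, before `flyte.init_from_config()`.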

When the run completes, open the Flyte report to review each company's briefing: the thesis, recent developments, risks, watch items, and You.com source citations.

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/frontier-ai ===

# Frontier AI

Tutorials for frontier-model pretraining, automated experimentation, and large-scale AI workloads.

### **Frontier AI > Distributed LLM pretraining**

Pretrain large language models at scale with PyTorch Lightning, FSDP, and H200 GPUs, featuring streaming data and real-time metrics.

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/frontier-ai/distributed-pretraining ===

# Distributed LLM pretraining

When training large models, infrastructure should not be the hardest part. The real work is in the model architecture, the data, and the hyperparameters. In practice, though, teams often spend weeks just trying to get distributed training to run reliably.

And when it breaks, it usually breaks in familiar ways: out-of-memory crashes, corrupted checkpoints, data loaders that silently fail, or runs that hang with no obvious explanation.

Most distributed training tutorials focus on PyTorch primitives. This one focuses on getting something that actually ships. We go into the technical details, such as how FSDP shards parameters, why gradient clipping behaves differently at scale, and how streaming datasets reduce memory pressure, but always with the goal of building a system that works in production.

Real training jobs need more than a training loop. They need checkpointing, fault tolerance, data streaming, visibility into what’s happening, and the ability to recover from failures. In this tutorial, we build all of that using Flyte, without having to stand up or manage any additional infrastructure.

> [!NOTE]
> The full code is available in the [unionai-examples repository](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/pretraining/train.py).

## Overview

We're going to pretrain a GPT-2 style language model from scratch. This involves training on raw text data starting from randomly initialized weights, rather than fine-tuning or adapting a pretrained model. This is the same process used to train the original GPT-2, LLaMA, and most other foundation models.

The model learns by predicting the next token. Given "The cat sat on the", it learns to predict "mat". Do this billions of times across terabytes of text, and the model develops surprisingly sophisticated language understanding. That's pretraining.
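
In concrete terms, predicting the next token means every training sequence serves as its own label: the targets are the inputs shifted left by one position. The sketch below builds those pairs from a flat list of token IDs in plain Python. The IDs (`100` to `108`) and the `seq_len` of 4 are made up for illustration; the tutorial's model uses sequences of up to 2,048 tokens. `next_token_pairs` is a hypothetical helper, not code from the tutorial.

```python
def next_token_pairs(token_ids: list[int], seq_len: int) -> list[tuple[list[int], list[int]]]:
    """Split a token stream into (input, target) windows for next-token prediction.

    Each window takes seq_len + 1 tokens; the target is the input shifted by one,
    so input position i is trained to predict target position i.
    Only full windows are kept.
    """
    pairs = []
    for start in range(0, len(token_ids) - seq_len, seq_len):
        window = token_ids[start : start + seq_len + 1]
        pairs.append((window[:-1], window[1:]))
    return pairs


tokens = list(range(100, 109))  # hypothetical token IDs
for inputs, targets in next_token_pairs(tokens, seq_len=4):
    print(inputs, "->", targets)
# [100, 101, 102, 103] -> [101, 102, 103, 104]
# [104, 105, 106, 107] -> [105, 106, 107, 108]
```

A model with causal attention sees only `inputs[:i + 1]` when predicting `targets[i]`. One forward pass over a window therefore yields `seq_len` next-token predictions, each scored with cross-entropy.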

The challenge is scale. A 30B-parameter model doesn't fit on a single GPU. The training dataset, [SlimPajama](https://huggingface.co/datasets/cerebras/SlimPajama-627B) in our case, contains 627 billion tokens. Training runs last for days or even weeks. To make this work, you need:

- **Distributed training**: Split the model across multiple GPUs using [FSDP (Fully Sharded Data Parallel)](https://docs.pytorch.org/tutorials/intermediate/FSDP_tutorial.html)
- **Data streaming**: Pull training data on-demand instead of downloading terabytes upfront
- **Checkpointing**: Save progress regularly so a failure doesn’t wipe out days of compute
- **Observability**: See what's happening inside a multi-day training run
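
A back-of-the-envelope estimate puts "doesn't fit on a single GPU" in numbers. It assumes AdamW training that keeps 16 bytes of state per parameter. That figure covers fp32 weights, fp32 gradients, and two fp32 optimizer moments; mixed-precision setups with an fp32 master copy land at roughly the same figure. The estimate ignores activations, buffers, and framework overhead, and uses the H200's 141 GB of HBM. These are illustrative assumptions, not measurements from this tutorial's run.

```python
GB = 1e9  # decimal gigabytes, as GPU memory is usually quoted
H200_GB = 141  # HBM capacity of one H200


def training_state_gb(n_params: float, bytes_per_param: int = 16) -> float:
    """Weights + gradients + AdamW moments, excluding activations (assumption)."""
    return n_params * bytes_per_param / GB


def per_gpu_gb_fully_sharded(n_params: float, n_gpus: int, bytes_per_param: int = 16) -> float:
    """With full sharding (FSDP FULL_SHARD), each GPU holds ~1/n_gpus of that state."""
    return training_state_gb(n_params, bytes_per_param) / n_gpus


params = 30e9  # a 30B-parameter model
total = training_state_gb(params)
per_gpu = per_gpu_gb_fully_sharded(params, n_gpus=8)
print(f"total: {total:.0f} GB, fits on one H200: {total < H200_GB}")
print(f"per GPU across 8 H200s: {per_gpu:.0f} GB, fits: {per_gpu < H200_GB}")
# total: 480 GB, fits on one H200: False
# per GPU across 8 H200s: 60 GB, fits: True
```

Under these assumptions, about 81 GB per GPU remains for activations. Activation memory grows with micro-batch size and sequence length.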

We’ll build a Flyte pipeline that takes care of all of this, using three tasks with clearly defined responsibilities:

1. **Data preparation**: Tokenizes your dataset and converts it to MDS (Mosaic Data Shard) format for streaming. This Flyte task is cached, so it runs once and its output is reused across runs.
2. **Distributed training**: Runs FSDP across 8 H200 GPUs. Flyte's `Elastic` plugin handles the distributed setup. Checkpoints upload to S3 automatically via Flyte's `File` abstraction.
3. **Real-time reporting**: Streams loss curves and training metrics to Flyte Reports, a live dashboard integrated into the Flyte UI.

Why three separate tasks? Flyte makes this separation efficient:

- **Caching**: The data preparation step runs once. On subsequent runs with the same inputs, Flyte skips it entirely and reuses the cached output.
- **Resource isolation**: The training task holds expensive H200 GPUs only while it trains; the driver runs on inexpensive CPU instances.
- **Fault boundaries**: If training fails, the data preparation step does not re-run. Training can resume directly from the most recent checkpoint.

## Implementation

Let's walk through the code. We'll start with the infrastructure setup, build the model, then wire everything together into a pipeline.

### Setting up the environment

Every distributed training job needs a consistent environment across all nodes. Flyte provides it with a container image, which we define after the imports and configuration constants:

```
import logging
import math
import os
from pathlib import Path
from typing import Optional

import flyte
import flyte.report
import lightning as L
import numpy as np
import torch
import torch.nn as nn
from flyte.io import Dir, File
from flyteplugins.pytorch.task import Elastic
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*

The imports tell the story: `flyte` for orchestration, `flyte.report` for live dashboards, `lightning` for training loop management, and `Elastic` from Flyte's PyTorch plugin. `Elastic` is the key piece: it configures PyTorch's distributed launch, so you don't write any distributed setup code yourself.
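
PyTorch's elastic launcher (the machinery behind `torchrun`) starts one worker process per GPU. Assuming `Elastic` starts workers the same way, each process discovers its place in the job from environment variables such as `RANK`, `LOCAL_RANK`, and `WORLD_SIZE`. Lightning reads these for you, so the tutorial's code never touches them. The sketch below uses a hypothetical helper, not part of Flyte or Lightning, to show what each worker sees. It falls back to single-process defaults when the variables are absent.

```python
import os
from collections.abc import Mapping
from dataclasses import dataclass


@dataclass(frozen=True)
class WorkerInfo:
    rank: int  # global index of this process across all nodes
    local_rank: int  # index on this node; usually selects the CUDA device
    world_size: int  # total number of processes (nodes x GPUs per node)


def worker_info(env: Mapping[str, str] = os.environ) -> WorkerInfo:
    """Read torchrun-style variables, defaulting to one process (hypothetical helper)."""
    return WorkerInfo(
        rank=int(env.get("RANK", 0)),
        local_rank=int(env.get("LOCAL_RANK", 0)),
        world_size=int(env.get("WORLD_SIZE", 1)),
    )


# What the 4th GPU on node 0 would see in a 1-node x 8-GPU job:
print(worker_info({"RANK": "3", "LOCAL_RANK": "3", "WORLD_SIZE": "8"}))
# WorkerInfo(rank=3, local_rank=3, world_size=8)
```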

```
NUM_NODES = 1
DEVICES_PER_NODE = 8
VOCAB_SIZE = (
    50257  # GPT-2 BPE tokenizer vocabulary size (constant across all model sizes)
)
N_POSITIONS = 2048  # Maximum sequence length (constant across all model sizes)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*

`NUM_NODES` and `DEVICES_PER_NODE` define the distributed topology. We're using 1 node with 8 GPUs, and you can scale up by changing `NUM_NODES`. The vocabulary size (50,257 tokens) matches GPT-2's [Byte Pair Encoding (BPE) tokenizer](https://huggingface.co/learn/llm-course/en/chapter6/5). The maximum sequence length of 2,048 tokens is twice the original GPT-2 context window of 1,024 tokens.
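
The topology constants determine the total number of training processes. That count, in turn, determines how many tokens each optimizer step consumes. In the sketch below, `micro_batch_size` and `grad_accum_steps` are hypothetical values chosen for illustration, not the tutorial's settings. `NUM_NODES`, `DEVICES_PER_NODE`, and `N_POSITIONS` are the constants defined above.

```python
NUM_NODES = 1
DEVICES_PER_NODE = 8
N_POSITIONS = 2048


def tokens_per_step(micro_batch_size: int, grad_accum_steps: int = 1) -> int:
    """Tokens consumed by one optimizer step under data parallelism (FSDP included).

    Every rank processes its own micro-batches of full-length sequences, and
    gradients are averaged across all ranks before the optimizer step.
    """
    world_size = NUM_NODES * DEVICES_PER_NODE
    return world_size * micro_batch_size * grad_accum_steps * N_POSITIONS


# Hypothetical settings: 4 sequences per GPU, 8 gradient-accumulation steps.
print(tokens_per_step(micro_batch_size=4, grad_accum_steps=8))  # 524288
```

Doubling `NUM_NODES` doubles the tokens per step. If you scale out, you may want to reduce accumulation steps or revisit the learning rate to keep training dynamics comparable.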

```
image = flyte.Image.from_debian_base(
    name="distributed_training_h200"
).with_pip_packages(
    "transformers==4.57.3",
    "datasets==4.4.1",
    "tokenizers==0.22.1",
    "huggingface-hub==0.34.0",
    "mosaicml-streaming>=0.7.0",
    "pyarrow==22.0.0",
    "flyteplugins-pytorch>=2.0.0b33",
    "torch==2.9.1",
    "lightning==2.5.6",
    "tensorboard==2.20.0",
    "sentencepiece==0.2.1",
)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*

Flyte builds this container automatically when the pipeline is run. All dependencies required for distributed training, including PyTorch, Lightning, the streaming library, and NCCL for GPU communication, are baked in. There's no Dockerfile to maintain and no "works on my machine" debugging.

### Declaring resource requirements

Different parts of the pipeline need different resources. Data tokenization needs CPU and memory. Training needs GPUs. The driver just coordinates. Flyte's `TaskEnvironment` lets you declare exactly what each task needs:

```
data_loading_env = flyte.TaskEnvironment(
    name="data_loading_h200",
    image=image,
    resources=flyte.Resources(cpu=5, memory="28Gi", disk="100Gi"),
    env_vars={
        "HF_DATASETS_CACHE": "/tmp/hf_cache",  # Cache directory for datasets
        "TOKENIZERS_PARALLELISM": "true",  # Enable parallel tokenization
    },
    cache="auto",
)

distributed_llm_training_env = flyte.TaskEnvironment(
    name="distributed_llm_training_h200",
    image=image,
    resources=flyte.Resources(
        cpu=64,
        memory="512Gi",
        gpu=f"H200:{DEVICES_PER_NODE}",
        disk="1Ti",
        shm="16Gi",  # Explicit shared memory for NCCL communication
    ),
    plugin_config=Elastic(nnodes=NUM_NODES, nproc_per_node=DEVICES_PER_NODE),
    env_vars={
        "TORCH_DISTRIBUTED_DEBUG": "INFO",
        "NCCL_DEBUG": "WARN",
    },
    cache="auto",
)

driver_env = flyte.TaskEnvironment(
    name="llm_training_driver",
    image=image,
    resources=flyte.Resources(cpu=2, memory="4Gi"),
    cache="auto",
    depends_on=[data_loading_env, distributed_llm_training_env],
)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*

Let's break down the training environment, since this is where most of the complexity lives:

- **`gpu=f"H200:{DEVICES_PER_NODE}"`**: Flyte provisions 8 H200 GPUs per node. Each has 141 GB of memory, so a node offers roughly 1.1 TB of GPU memory in total, which with FSDP sharding is enough to train 30B+ parameter models.
- **`shm="16Gi"`**: This allocates explicit shared memory. NCCL (NVIDIA's communication library) uses shared memory for inter-GPU communication on the same node. Container runtimes often default to a small `/dev/shm` (64 MB in Docker), and when NCCL runs out of it you'll see cryptic errors such as "NCCL error: unhandled system error" that are hard to trace back to their cause.
- **`Elastic(nnodes=NUM_NODES, nproc_per_node=DEVICES_PER_NODE)`**: This is Flyte's integration with PyTorch's elastic launch. It handles process spawning (one process per GPU), rank assignment (each process knows its ID), and environment setup (master address, world size). This replaces the boilerplate typically written in shell scripts.
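Inside each worker process, PyTorch's elastic launcher (the machinery behind `torchrun`) exposes the rank and topology through environment variables such as `RANK`, `LOCAL_RANK`, `WORLD_SIZE`, and `MASTER_ADDR`. The sketch below assumes the `Elastic` plugin launches workers the same way `torchrun` does. The `read_dist_env` helper is hypothetical; Lightning reads these variables for you, so you would normally only need something like this for debugging.

```python
import os
from dataclasses import dataclass


@dataclass
class DistEnv:
    rank: int          # global rank across all nodes
    local_rank: int    # rank within this node (selects the GPU)
    world_size: int    # total number of processes (nodes x GPUs per node)
    master_addr: str   # address of the rendezvous host


def read_dist_env() -> DistEnv:
    """Hypothetical helper: read torchrun-style variables, falling back to a
    single-process setup when they are absent (e.g. local debugging)."""
    return DistEnv(
        rank=int(os.environ.get("RANK", "0")),
        local_rank=int(os.environ.get("LOCAL_RANK", "0")),
        world_size=int(os.environ.get("WORLD_SIZE", "1")),
        master_addr=os.environ.get("MASTER_ADDR", "localhost"),
    )


if __name__ == "__main__":
    env = read_dist_env()
    print(f"rank {env.rank}/{env.world_size}, local GPU {env.local_rank}")
```

With `NUM_NODES = 1` and `DEVICES_PER_NODE = 8`, you would expect `WORLD_SIZE` to be 8 and `LOCAL_RANK` to range from 0 to 7.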

The `driver_env` is intentionally lightweight, using 2 CPUs and 4 GB of memory. Its role is limited to orchestrating tasks and passing data between them, so allocating GPUs here would be unnecessary.

### Model configurations

Training a 1.5B model uses different hyperparameters than training a 65B model. Rather than hardcoding values, we define presets:

```
MODEL_CONFIGS = {
    "1.5B": {
        "n_embd": 2048,
        "n_layer": 24,
        "n_head": 16,
        "batch_size": 8,
        "learning_rate": 6e-4,
        "checkpoint_every_n_steps": 10,
        "report_every_n_steps": 5,
        "val_check_interval": 100,
    },  # Good for testing and debugging
    "30B": {
        "n_embd": 6656,
        "n_layer": 48,
        "n_head": 52,
        "batch_size": 1,
        "learning_rate": 1.6e-4,
        "checkpoint_every_n_steps": 7500,
        "report_every_n_steps": 200,
        "val_check_interval": 1000,
    },
    "65B": {
        "n_embd": 8192,
        "n_layer": 80,
        "n_head": 64,
        "batch_size": 1,
        "learning_rate": 1.5e-4,
        "checkpoint_every_n_steps": 10000,
        "report_every_n_steps": 250,
        "val_check_interval": 2000,
    },
}

def get_model_config(model_size: str) -> dict:
    if model_size not in MODEL_CONFIGS:
        available = ", ".join(MODEL_CONFIGS.keys())
        raise ValueError(f"Unknown model size: {model_size}. Available: {available}")

    return MODEL_CONFIGS[model_size]
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*

A few things to notice:

- **Batch size decreases with model size**: For a fixed GPU memory budget, larger models consume more memory for parameters, optimizer state, and activations, leaving less room for per-GPU batch size. For example, a 1.5B parameter model may fit a batch size of 8 per GPU, while a 65B model may only fit a batch size of 1. This is typically compensated for using gradient accumulation to maintain a larger effective batch size.
- **Learning rate decreases with model size**: Larger models are more sensitive to optimization instability and typically require lower learning rates. The values here follow empirical best practices used in large-scale language model training, informed by work such as the [Chinchilla study](https://arxiv.org/pdf/2203.15556) on compute-optimal scaling.
- **Checkpoint interval increases with model size**: Checkpointing a 65B model is expensive because the checkpoint is huge. We checkpoint less often (every 10,000 steps rather than every 10 for the 1.5B model), while keeping the interval short enough that a failure doesn't lose too much progress.

The 1.5B config is good for testing your setup before committing to a serious training run.
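To see roughly how the presets map to parameter counts, here is a back-of-the-envelope estimate based on the `GPTBlock` shown below (attention with biases, two LayerNorms, and a 4× MLP). It assumes, without confirmation from the source, that `GPTModel` uses learned positional embeddings, one final LayerNorm, and a tied output head; `MODEL_SHAPES`, `block_params`, and `estimate_params` are hypothetical names.

```python
# Rough parameter-count estimate for the GPTBlock-based model in this tutorial.
# Assumptions (not confirmed by the source): learned positional embeddings,
# one final LayerNorm, and a tied output head that adds no extra weights.

VOCAB_SIZE = 50257
N_POSITIONS = 2048

MODEL_SHAPES = {  # (n_embd, n_layer) taken from MODEL_CONFIGS
    "1.5B": (2048, 24),
    "30B": (6656, 48),
    "65B": (8192, 80),
}


def block_params(d: int) -> int:
    """Parameters in one GPTBlock with n_inner = 4 * d."""
    attn = 3 * d * d + 3 * d + d * d + d      # in_proj (Q, K, V) + out_proj, with biases
    norms = 2 * (2 * d)                       # ln_1 and ln_2 (weight + bias each)
    mlp = d * 4 * d + 4 * d + 4 * d * d + d   # two Linear layers with biases
    return attn + norms + mlp                 # = 12*d^2 + 13*d


def estimate_params(model_size: str) -> int:
    d, n_layer = MODEL_SHAPES[model_size]
    embeddings = VOCAB_SIZE * d + N_POSITIONS * d  # token + position tables
    final_norm = 2 * d
    return n_layer * block_params(d) + embeddings + final_norm


for size in MODEL_SHAPES:
    print(f"{size:>5}: ~{estimate_params(size) / 1e9:.2f}B parameters")
```

Under these assumptions the presets come out at roughly 1.3B, 26B, and 65B parameters, so the preset names are best read as nominal size classes rather than exact counts.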

### Building the GPT model

Now for the model itself. We're building a GPT-2 style decoder-only transformer from scratch.

First, the configuration class:

```
class GPTConfig:
    """Configuration for GPT model."""

    def __init__(
        self,
        vocab_size: int = VOCAB_SIZE,
        n_positions: int = N_POSITIONS,
        n_embd: int = 2048,
        n_layer: int = 24,
        n_head: int = 16,
        n_inner: Optional[int] = None,
        activation_function: str = "gelu_new",
        dropout: float = 0.1,
        layer_norm_epsilon: float = 1e-5,
    ):
        self.vocab_size = vocab_size
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_inner = n_inner if n_inner is not None else 4 * n_embd
        self.activation_function = activation_function
        self.dropout = dropout
        self.layer_norm_epsilon = layer_norm_epsilon
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*

The key architectural parameters:

- **`n_embd`**: The hidden (embedding) dimension. Larger values increase model capacity but also increase memory and compute requirements.
- **`n_layer`**: The number of transformer blocks. Model depth strongly influences expressiveness and performance.
- **`n_head`**: The number of attention heads. Each head can attend to different patterns or relationships in the input.
- **`n_inner`**: The hidden dimension of the feed-forward network (MLP), typically set to 4x the embedding dimension.

Next, we define a single transformer block:

```
class GPTBlock(nn.Module):
    """Transformer block with causal self-attention."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = nn.MultiheadAttention(
            config.n_embd,
            config.n_head,
            dropout=config.dropout,
            batch_first=True,
        )
        self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        # Get activation function from config
        ACT_FNS = {
            "gelu": nn.GELU(),
            "gelu_new": nn.GELU(approximate="tanh"),  # GPT-2 uses approximate GELU
            "relu": nn.ReLU(),
            "silu": nn.SiLU(),
            "swish": nn.SiLU(),  # SiLU = Swish
        }
        act_fn = ACT_FNS.get(config.activation_function, nn.GELU())

        self.mlp = nn.Sequential(
            nn.Linear(config.n_embd, config.n_inner),
            act_fn,
            nn.Linear(config.n_inner, config.n_embd),
            nn.Dropout(config.dropout),
        )

    def forward(self, x, causal_mask, key_padding_mask=None):
        x_normed = self.ln_1(x)

        # Self-attention with causal and padding masks
        attn_output, _ = self.attn(
            x_normed,  # query
            x_normed,  # key
            x_normed,  # value
            attn_mask=causal_mask,  # Causal mask: (seq_len, seq_len)
            key_padding_mask=key_padding_mask,  # Padding mask: (batch, seq_len)
            need_weights=False,
        )
        x = x + attn_output

        # MLP with residual
        x = x + self.mlp(self.ln_2(x))
        return x
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*

Each block has two sub-layers, causal self-attention and a feed-forward MLP, each preceded by a LayerNorm and wrapped in a residual connection. The causal mask ensures that each position can attend only to itself and earlier tokens in the sequence, so the model can't "cheat" by looking at the token it is supposed to predict. This is what makes it *autoregressive*.
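The tutorial doesn't show how `GPTModel` builds its causal mask. One common way to build one for `nn.MultiheadAttention` is sketched below; this is an assumption, not necessarily what `GPTModel` does. A boolean mask marks with `True` the positions a query may *not* attend to, so the causal mask is the strictly upper-triangular part of the matrix. The check at the end confirms that changing only the last token leaves the outputs at earlier positions unchanged.

```python
import torch
import torch.nn as nn


def causal_mask(seq_len: int) -> torch.Tensor:
    # Boolean mask for nn.MultiheadAttention: True blocks attention.
    # Everything above the diagonal (future tokens) is blocked.
    return torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)


torch.manual_seed(0)
embed_dim, n_head, seq_len = 16, 4, 6
attn = nn.MultiheadAttention(embed_dim, n_head, batch_first=True)
attn.eval()  # disable dropout so outputs are deterministic

x = torch.randn(1, seq_len, embed_dim)
mask = causal_mask(seq_len)
out, _ = attn(x, x, x, attn_mask=mask, need_weights=False)

# Change only the last token: outputs at earlier positions should not move.
x2 = x.clone()
x2[:, -1] = torch.randn(embed_dim)
out2, _ = attn(x2, x2, x2, attn_mask=mask, need_weights=False)
print(torch.allclose(out[:, :-1], out2[:, :-1], atol=1e-6))
```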

The full `GPTModel` class (see the complete code) stacks these blocks and adds token and positional embeddings. One important detail is that the input token embedding matrix is shared with the output projection layer, a technique known as [weight tying](https://mbrenndoerfer.com/writing/weight-tying-shared-embeddings-transformers). This saves `vocab_size × n_embd` parameters: about 103 million for the 1.5B configuration (50,257 × 2,048) and about 412 million for the 65B configuration (50,257 × 8,192). Weight tying also often improves generalization and training stability.
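In PyTorch, weight tying can be done by assigning the embedding's `Parameter` to the output layer, because `nn.Embedding(vocab, d).weight` and `nn.Linear(d, vocab).weight` both have shape `(vocab, d)`. The module below is a hypothetical minimal example, not the tutorial's `GPTModel`.

```python
import torch
import torch.nn as nn


class TinyTiedLM(nn.Module):
    """Hypothetical minimal LM showing weight tying (not the tutorial's GPTModel)."""

    def __init__(self, vocab_size: int, n_embd: int):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, n_embd)               # weight: (vocab, n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)  # weight: (vocab, n_embd)
        self.lm_head.weight = self.wte.weight  # share one Parameter for input and output

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.lm_head(self.wte(input_ids))  # logits: (batch, seq_len, vocab)


model = TinyTiedLM(vocab_size=50257, n_embd=64)
# parameters() yields each shared Parameter once, so the tied matrix is counted once.
n_params = sum(p.numel() for p in model.parameters())
print(n_params)  # 50257 * 64 = 3216448
```

Without the assignment, the same model would have twice as many parameters, since the output layer would carry its own `(vocab, n_embd)` matrix.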

### The Lightning training module

PyTorch Lightning handles the training loop boilerplate. We wrap our model in a `LightningModule` that defines how to train it:

```
class GPTPreTrainingModule(L.LightningModule):
    """PyTorch Lightning module for GPT pre-training."""

    def __init__(
        self,
        vocab_size: int = 50257,
        n_positions: int = 2048,
        n_embd: int = 2048,
        n_layer: int = 24,
        n_head: int = 16,
        learning_rate: float = 6e-4,
        weight_decay: float = 0.1,
        warmup_steps: int = 2000,
        max_steps: int = 100000,
    ):
        super().__init__()
        self.save_hyperparameters()

        config = GPTConfig(
            vocab_size=vocab_size,
            n_positions=n_positions,
            n_embd=n_embd,
            n_layer=n_layer,
            n_head=n_head,
        )
        self.model = GPTModel(config)

    def forward(self, input_ids, attention_mask=None):
        return self.model(input_ids, attention_mask)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*

The `save_hyperparameters()` call is important because it stores all constructor arguments in the checkpoint. This allows the model to be reloaded later without having to manually reconstruct the original configuration.

The training and validation steps implement standard causal language modeling, where the model is trained to predict the next token given all previous tokens in the sequence.

```
    def training_step(self, batch, _batch_idx):
        # Convert int32 to int64 (long) - MDS stores as int32 but PyTorch expects long
        input_ids = batch["input_ids"].long()
        labels = batch["labels"].long()

        # Get attention mask if present (optional, for padded sequences)
        # attention_mask: 1 = real token, 0 = padding
        # Note: Current data pipeline creates fixed-length sequences without padding,
        # so attention_mask is not present. If using padded sequences, ensure:
        #   - Padded positions in labels are set to -100 (ignored by cross_entropy)
        #   - attention_mask marks real tokens (1) vs padding (0)
        attention_mask = batch.get("attention_mask", None)

        # Forward pass (causal mask is created internally in GPTModel)
        logits = self(input_ids, attention_mask=attention_mask)

        # Shift logits and labels for causal language modeling
        # Before shift: labels[i] = input_ids[i]
        # After shift: predict input_ids[i+1] from input_ids[:i+1]
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # Calculate loss
        loss = nn.functional.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100,
        )

        # Log loss
        self.log(
            "train/loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
        )

        # Calculate and log perplexity only on epoch (exp is costly, less frequent is fine)
        perplexity = torch.exp(torch.clamp(loss, max=20.0))
        self.log(
            "train/perplexity",
            perplexity,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
        )

        return loss

    def validation_step(self, batch, _batch_idx):
        # Convert int32 to int64 (long) - MDS stores as int32 but PyTorch expects long
        input_ids = batch["input_ids"].long()
        labels = batch["labels"].long()

        # Get attention mask if present (optional, for padded sequences)
        attention_mask = batch.get("attention_mask", None)

        # Forward pass (causal mask is created internally in GPTModel)
        logits = self(input_ids, attention_mask=attention_mask)

        # Shift logits and labels
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # Calculate loss
        loss = nn.functional.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100,
        )

        # Log loss
        self.log("val/loss", loss, prog_bar=True, sync_dist=True)

        # Calculate and log perplexity (exp is costly, but validation is infrequent so OK)
        perplexity = torch.exp(torch.clamp(loss, max=20.0))
        self.log("val/perplexity", perplexity, prog_bar=True, sync_dist=True)

        return loss
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*

The model performs a forward pass with a causal (autoregressive) mask created internally, so each token can attend only to itself and earlier positions. To align predictions with targets, the logits and labels are shifted so that the output at position `i` is used to predict token `i + 1`.

Loss is computed using cross-entropy over the shifted logits and labels, ignoring any positions labeled `-100`. Both steps log loss and perplexity (`exp(loss)`, clamped to avoid overflow), and `sync_dist=True` synchronizes each metric across distributed workers.
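The snippet below isolates the shift-and-score logic from `training_step` so you can check it on toy tensors; `causal_lm_loss` is a hypothetical name. With uniform logits over a vocabulary of `V` tokens, the loss is `ln V` and the perplexity is `V`. This makes a handy sanity check: a freshly initialized model with near-uniform outputs should start with a loss of roughly `ln 50257 ≈ 10.8`.

```python
import math

import torch
import torch.nn.functional as F


def causal_lm_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Next-token loss using the same shift as training_step.
    logits: (batch, seq_len, vocab); labels: (batch, seq_len)."""
    shift_logits = logits[..., :-1, :].contiguous()  # predictions from positions 0..n-2
    shift_labels = labels[..., 1:].contiguous()      # targets are tokens 1..n-1
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,  # positions labeled -100 (e.g. padding) are skipped
    )


vocab = 10
labels = torch.tensor([[3, 1, 4, 1, 5]])
# Uniform logits: every token gets probability 1/vocab.
logits = torch.zeros(1, labels.size(1), vocab)
loss = causal_lm_loss(logits, labels)
perplexity = torch.exp(torch.clamp(loss, max=20.0))
print(loss.item(), math.log(vocab))  # loss should equal ln(vocab)
print(perplexity.item())             # close to vocab: a uniform 10-way guess
```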

The optimizer setup is where a lot of training stability comes from:

```
    def configure_optimizers(self):
        # Separate parameters into weight decay and no weight decay groups
        decay_params = []
        no_decay_params = []

        for param in self.model.parameters():
            if param.requires_grad:
                # 1D parameters (biases, LayerNorm) don't get weight decay
                # 2D+ parameters (weight matrices) get weight decay
                if param.ndim == 1:
                    no_decay_params.append(param)
                else:
                    decay_params.append(param)

        optimizer_grouped_parameters = [
            {"params": decay_params, "weight_decay": self.hparams.weight_decay},
            {"params": no_decay_params, "weight_decay": 0.0},
        ]

        # AdamW optimizer
        optimizer = torch.optim.AdamW(
            optimizer_grouped_parameters,
            lr=self.hparams.learning_rate,
            betas=(0.9, 0.95),
            eps=1e-8,
        )

        # Learning rate scheduler: warmup + cosine decay
        # Warmup: linear increase from 0 to 1.0 over warmup_steps
        # Decay: cosine decay from 1.0 to 0.0 over remaining steps
        def lr_lambda(current_step):
            if current_step < self.hparams.warmup_steps:
                # Linear warmup
                return float(current_step) / float(max(1, self.hparams.warmup_steps))

            # Cosine decay after warmup
            progress = (current_step - self.hparams.warmup_steps) / max(
                1, self.hparams.max_steps - self.hparams.warmup_steps
            )
            # Cosine annealing from 1.0 to 0.0 (returns float, not tensor)
            return 0.5 * (1.0 + math.cos(progress * math.pi))

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
            },
        }
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*

Two important choices here:

1. **Separate weight decay groups**: We apply weight decay only to weight matrices, not to biases or LayerNorm parameters (the code identifies these as 1D tensors). This follows the original BERT implementation and is now standard practice, since regularizing biases and normalization parameters generally brings no benefit and can hurt performance.
2. **Cosine learning rate schedule with warmup**: The learning rate starts at zero and ramps up linearly during warmup, which helps stabilize early training when gradients are noisy, then decays to zero along a cosine curve. This schedule is a common default for transformer pretraining and generally performs better than a constant rate or step decay.
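To see the schedule's shape, the sketch below re-implements the same warmup-plus-cosine rule as `lr_lambda` as a standalone function (`lr_multiplier` is a hypothetical name) and prints the resulting learning rate at a few steps. `LambdaLR` multiplies the optimizer's base learning rate by this factor; the defaults match `GPTPreTrainingModule`'s `warmup_steps=2000` and `max_steps=100000`.

```python
import math


def lr_multiplier(step: int, warmup_steps: int = 2000, max_steps: int = 100_000) -> float:
    """Same warmup + cosine rule as lr_lambda in configure_optimizers."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)  # linear ramp from 0 to 1
    progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(progress * math.pi))  # cosine from 1 to 0


base_lr = 6e-4  # the 1.5B preset's learning rate
for step in (0, 1_000, 2_000, 51_000, 100_000):
    print(f"step {step:>7}: lr = {base_lr * lr_multiplier(step):.2e}")
```

The rate peaks at the end of warmup, is back to half its peak midway through the decay (step 51,000), and reaches zero at `max_steps`. If the module's `max_steps` hyperparameter is larger than the Trainer's `max_steps`, training stops before the rate has fully decayed.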

### Checkpointing for fault tolerance

Training a 30B-parameter model for 15,000 steps can take days. Over runs that long, hardware failures and spot instance preemptions are likely, which makes checkpointing essential.

```
class S3CheckpointCallback(L.Callback):
    """
    Periodically upload checkpoints to S3 for durability and resumption.

    This ensures checkpoints are safely stored in remote storage even if
    the training job is interrupted or the instance fails.
    """

    def __init__(self, checkpoint_dir: Path, upload_every_n_steps: int):
        super().__init__()
        self.checkpoint_dir = checkpoint_dir
        self.upload_every_n_steps = upload_every_n_steps
        self.last_uploaded_step = -1

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Upload checkpoint to S3 every N steps."""
        if trainer.global_rank != 0:
            return  # Only upload from rank 0

        current_step = trainer.global_step

        # Upload every N steps (aligns with ModelCheckpoint's every_n_train_steps)
        if (
            current_step % self.upload_every_n_steps == 0
            and current_step > self.last_uploaded_step
            and current_step > 0
        ):
            try:
                # Find the most recent checkpoint file
                checkpoint_files = list(self.checkpoint_dir.glob("*.ckpt"))
                if not checkpoint_files:
                    print("No checkpoint files found to upload")
                    return

                # Get the latest checkpoint (by modification time)
                latest_checkpoint = max(
                    checkpoint_files, key=lambda p: p.stat().st_mtime
                )

                # Upload the checkpoint file directly to S3 using File.from_local_sync
                checkpoint_file = File.from_local_sync(str(latest_checkpoint))
                print(f"Checkpoint uploaded to S3 at: {checkpoint_file.path}")

                self.last_uploaded_step = current_step
            except Exception as e:
                print(f"Warning: Failed to upload checkpoint to S3: {e}")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*

This callback fires after every training batch and, every `N` optimizer steps, uploads the most recent local checkpoint to durable storage. The key line is `File.from_local_sync()`, a Flyte abstraction for uploading files. There are no blob store credentials to manage and no bucket paths to hardcode: Flyte automatically uses the storage backend configured for your cluster.

The upload only runs on rank 0. With DDP, every GPU holds an identical copy of the model; with FSDP and `state_dict_type="full"`, the sharded state is gathered into a single full checkpoint when it is saved. Either way there is one complete checkpoint, and having all 8 ranks upload it would be wasteful and could cause race conditions.

When you restart a failed run, pass the checkpoint via `resume_checkpoint` so training resumes exactly where it left off, including the same step count, optimizer state, and learning rate schedule position.
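How `train_distributed_llm` turns the `resume_checkpoint` directory into the `ckpt_path` passed to `trainer.fit()` isn't shown here. A minimal sketch, assuming the directory has already been downloaded locally and contains Lightning `.ckpt` files, could pick the newest file the same way `S3CheckpointCallback` does. `latest_checkpoint` is a hypothetical helper.

```python
from pathlib import Path
from typing import Optional


def latest_checkpoint(checkpoint_dir: Path) -> Optional[str]:
    """Return the most recently modified *.ckpt file, or None if there is none.
    Hypothetical helper; mirrors the selection logic in S3CheckpointCallback."""
    checkpoints = list(Path(checkpoint_dir).glob("*.ckpt"))
    if not checkpoints:
        return None  # no checkpoint: training starts from scratch
    return str(max(checkpoints, key=lambda p: p.stat().st_mtime))
```

The result can be passed directly as `trainer.fit(..., ckpt_path=latest_checkpoint(local_dir))`; Lightning treats `ckpt_path=None` as a fresh start.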

### Real-time metrics with Flyte reports

Multi-day training runs need observability. Is the loss decreasing? Did training diverge? Is the learning rate schedule behaving correctly? Flyte Reports let you build live dashboards directly in the UI:

```
class FlyteReportingCallback(L.Callback):
    """Custom Lightning callback to report training metrics to Flyte Report."""

    def __init__(self, report_every_n_steps: int = 100):
        super().__init__()
        self.report_every_n_steps = report_every_n_steps
        self.metrics_history = {
            "step": [],
            "train_loss": [],
            "learning_rate": [],
            "val_loss": [],
            "val_perplexity": [],
        }
        self.initialized_report = False
        self.last_logged_step = -1

    def on_train_start(self, trainer, pl_module):
        """Initialize the live dashboard on training start."""
        if trainer.global_rank == 0 and not self.initialized_report:
            self._initialize_report()
            self.initialized_report = True
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*

The `_initialize_report` method (see complete code) creates an HTML/JavaScript dashboard with interactive charts. The callback then calls `flyte.report.log()` every `N` steps to push new metrics, and the charts update in real time, so you can watch the loss curve descend while training runs.

There is no need to deploy Grafana, configure Prometheus, or keep a TensorBoard server running. Using `flyte.report.log()` is sufficient to get live training metrics directly in the Flyte UI.

![Metrics viz](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/tutorials/distributed-llm-pretraining/metrics.png)

### Streaming data at scale

Training datasets are massive. SlimPajama contains 627 billion tokens and spans hundreds of gigabytes even when compressed. Downloading the entire dataset to each training node before starting would take hours and waste storage.

Instead, we convert the data to MDS (MosaicML Data Shard) format and stream it during training:

```
@data_loading_env.task
async def load_and_prepare_streaming_dataset(
    dataset_name: str,
    dataset_config: Optional[str],
    max_length: int,
    train_split: str,
    val_split: Optional[str],
    max_train_samples: Optional[int],
    max_val_samples: Optional[int],
    shard_size_mb: int,
) -> Dir:
    """Tokenize dataset and convert to MDS format for streaming."""
    from datasets import load_dataset
    from streaming import MDSWriter
    from transformers import GPT2TokenizerFast

    output_dir = Path("/tmp/streaming_dataset")
    output_dir.mkdir(parents=True, exist_ok=True)

    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

    # MDS schema: what each sample contains
    columns = {
        "input_ids": "ndarray:int32",
        "labels": "ndarray:int32",
    }
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*

This task does three things:

1. **Tokenizes the text** using GPT-2's BPE tokenizer
2. **Concatenates documents** into fixed-length sequences (no padding waste)
3. **Writes shards** to storage in a format optimized for streaming
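Step 2, concatenating documents into fixed-length sequences, can be sketched in plain Python. This is a simplified illustration: `pack_sequences` is a hypothetical helper, and using the end-of-text token as a document separator is an assumption (GPT-2's tokenizer uses `eos_token_id` 50256). In the real task, each block would be written to an `MDSWriter` as `int32` arrays matching the `columns` schema above.

```python
from typing import Iterable, List


def pack_sequences(
    docs: Iterable[List[int]], eos_id: int, max_length: int
) -> List[List[int]]:
    """Concatenate tokenized documents (each followed by EOS) and cut the stream
    into fixed-length blocks. The final partial block is dropped, so no padding
    is ever needed. Hypothetical helper, not the tutorial's exact code."""
    buffer: List[int] = []
    blocks: List[List[int]] = []
    for tokens in docs:
        buffer.extend(tokens)
        buffer.append(eos_id)  # mark the document boundary
        while len(buffer) >= max_length:
            blocks.append(buffer[:max_length])
            buffer = buffer[max_length:]
    return blocks


docs = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
for block in pack_sequences(docs, eos_id=0, max_length=4):
    # labels can simply copy input_ids; training_step applies the one-token shift
    sample = {"input_ids": block, "labels": list(block)}
    print(sample["input_ids"])
```

Because blocks may span document boundaries, every token in every sequence is a real token, which is why the training step doesn't need an attention mask.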

The task returns a Flyte `Dir` object, which is a reference to the output location. It's not the data itself, just a pointer. When the training task receives this `Dir`, it streams shards on-demand rather than downloading everything upfront.

Because `data_loading_env` sets `cache="auto"`, Flyte caches this task's output based on its inputs. Run the pipeline twice with the same dataset config, and Flyte skips tokenization entirely on the second run. Change the dataset or sequence length, and the task re-runs.

### Distributed training with FSDP

Now we get to the core: training the model across multiple GPUs. Fully Sharded Data Parallel (FSDP) is what makes this possible for large models.

```
@distributed_llm_training_env.task(report=True)
def train_distributed_llm(
    prepared_dataset: Dir,
    resume_checkpoint: Optional[Dir],
    vocab_size: int,
    n_positions: int,
    n_embd: int,
    n_layer: int,
    n_head: int,
    batch_size: int,
    num_workers: int,
    max_steps: int,
    learning_rate: float,
    weight_decay: float,
    warmup_steps: int,
    use_fsdp: bool,
    checkpoint_upload_steps: int,
    checkpoint_every_n_steps: int,
    report_every_n_steps: int,
    val_check_interval: int,
    grad_accumulation_steps: int = 1,
) -> Optional[Dir]:
    # ... setup code ...
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*

Notice `report=True` on the task decorator. It enables Flyte Reports for this specific task.

The training task receives the prepared dataset as a `Dir` and streams data directly from storage:

```
    # StreamingDataset streams shards from the remote Flyte storage on-demand
    # It automatically detects torch.distributed context
    # and shards data across GPUs - each rank gets different data automatically
    train_dataset = StreamingDataset(
        remote=f"{remote_path}/train",  # Remote MDS shard location
        local=str(local_cache / "train"),  # Local cache for downloaded shards
        shuffle=True,  # Shuffle samples
        shuffle_algo="naive",  # Shuffling algorithm
        batch_size=batch_size,  # Used for shuffle buffer sizing
    )

    # Create validation StreamingDataset if it exists
    val_dataset = None
    try:
        val_dataset = StreamingDataset(
            remote=f"{remote_path}/validation",
            local=str(local_cache / "validation"),
            shuffle=False,  # No shuffling for validation
            batch_size=batch_size,
        )
        print(
            f"Validation dataset initialized with streaming from: {remote_path}/validation"
        )
    except Exception as e:
        print(f"No validation dataset found: {e}")

    # Create data loaders
    # StreamingDataset handles distributed sampling internally by detecting
    # torch.distributed.get_rank() and torch.distributed.get_world_size()
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        persistent_workers=True,
        drop_last=True,  # Drop incomplete batches for distributed training
        collate_fn=mds_collate_fn,  # Handle read-only arrays
    )

    # Create validation loader if validation dataset exists
    val_loader = None
    if val_dataset is not None:
        val_loader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=True,
            persistent_workers=True,
            drop_last=False,
            collate_fn=mds_collate_fn,
        )
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*

`prepared_dataset.path` provides the remote storage path for the dataset. MosaicML's `StreamingDataset` automatically shards data across GPUs so that each rank sees different samples, without requiring a manual distributed sampler. The credentials are already in the environment because Flyte set them up.

FSDP is where the memory savings come from. Instead of each GPU holding a full copy of the model, as with Distributed Data Parallel (DDP), FSDP shards the parameters, gradients, and optimizer states across all GPUs, so with 8 GPUs each one holds only 1/8 of the model state. When a layer needs to run, FSDP all-gathers its full parameters, runs the computation, then frees the gathered copy. In the backward pass it gathers the parameters again and reduce-scatters the gradients, so each GPU keeps only its own gradient shard.
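A back-of-the-envelope calculation shows why sharding matters. The sketch assumes about 16 bytes of model state per parameter (fp32 weights, fp32 gradients, and AdamW's two fp32 moment buffers, a common rule of thumb for mixed-precision training), ignores activations, and uses the nominal preset sizes as parameter counts. `model_state_gb` is a hypothetical name.

```python
# Back-of-the-envelope memory for model state only (weights, gradients, and
# AdamW moments); activations, buffers, and fragmentation are ignored.
# Assumption: ~16 bytes per parameter for mixed-precision training with AdamW.
BYTES_PER_PARAM = 16
GPU_MEMORY_GB = 141  # H200
NUM_GPUS = 8


def model_state_gb(n_params: float, num_shards: int = 1) -> float:
    """Per-GPU model-state memory in GB when state is split across num_shards GPUs."""
    return n_params * BYTES_PER_PARAM / num_shards / 1e9


for label, n_params in (("30B", 30e9), ("65B", 65e9)):
    ddp = model_state_gb(n_params)             # DDP: every GPU holds everything
    fsdp = model_state_gb(n_params, NUM_GPUS)  # FULL_SHARD: 1/8 per GPU
    print(f"{label}: DDP {ddp:,.0f} GB/GPU, FSDP {fsdp:,.0f} GB/GPU "
          f"(H200 has {GPU_MEMORY_GB} GB)")
```

Under these assumptions, neither model's state fits on a single 141 GB GPU with DDP. FSDP brings the 30B model down to about 60 GB per GPU, leaving room for activations, while the 65B model's roughly 130 GB per GPU leaves very little headroom on one 8-GPU node.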

```
    # Configure distributed strategy
    if use_fsdp:
        from torch.distributed.fsdp.wrap import ModuleWrapPolicy

        strategy = FSDPStrategy(
            auto_wrap_policy=ModuleWrapPolicy([GPTBlock]),
            activation_checkpointing_policy=None,
            cpu_offload=False,  # H200 has 141GB - no CPU offload needed
            state_dict_type="full",
            sharding_strategy="FULL_SHARD",
            process_group_backend="nccl",
        )
    else:
        from lightning.pytorch.strategies import DDPStrategy

        strategy = DDPStrategy(process_group_backend="nccl")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*

Wrapping at the `GPTBlock` level makes each transformer block its own FSDP unit. The unit size trades off communication overhead (more units means more, smaller all-gathers) against memory savings (smaller units mean less memory is needed to hold one unit's gathered parameters at a time).

One subtle detail is gradient clipping. Norm-based clipping needs the norm of the full gradient, but with FSDP each rank holds only a shard of the gradients, so the norm can't be computed from local data alone; as the comment in the trainer configuration below notes, this setup uses value-based clipping under FSDP. Value-based clipping clamps each individual gradient element to a fixed range, which every rank can do independently on its own shard with no coordination. The trade-off is that, unlike norm-based clipping, it can change the direction of the overall gradient.
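The difference between the two clipping modes is easiest to see on toy gradients. `torch.nn.utils.clip_grad_value_` clamps each element on its own, while `torch.nn.utils.clip_grad_norm_` rescales every gradient by a common factor derived from the global norm, which is why it needs all elements at once. This is a standalone PyTorch illustration, not how Lightning applies clipping internally; `make_grads` is a hypothetical helper.

```python
import torch


def make_grads() -> list:
    # Two fake parameters with hand-picked gradients.
    w1 = torch.nn.Parameter(torch.zeros(3))
    w2 = torch.nn.Parameter(torch.zeros(2))
    w1.grad = torch.tensor([0.5, -4.0, 2.0])
    w2.grad = torch.tensor([3.0, 0.1])
    return [w1, w2]


# Value clipping: each element is clamped to [-1, 1] independently.
# -4.0 becomes -1.0 while 0.5 is untouched, so the gradient direction changes.
params = make_grads()
torch.nn.utils.clip_grad_value_(params, clip_value=1.0)
print("value-clipped:", [p.grad.tolist() for p in params])

# Norm clipping: all gradients are scaled by one factor so the global L2 norm
# is at most 1. Computing that norm requires every element (all shards under FSDP).
params = make_grads()
total_norm = torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)
print("norm before clipping:", total_norm.item())
print("norm-clipped:", [p.grad.tolist() for p in params])
```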

```
    # Initialize trainer
    trainer = L.Trainer(
        strategy=strategy,
        accelerator="gpu",
        devices=DEVICES_PER_NODE,
        num_nodes=NUM_NODES,
        # Training configuration
        max_steps=max_steps,
        precision="bf16-mixed",  # BFloat16 for better numerical stability
        # Optimization
        gradient_clip_val=1.0,
        gradient_clip_algorithm=(
            "value" if use_fsdp else "norm"
        ),  # FSDP requires 'value', DDP can use 'norm'
        accumulate_grad_batches=grad_accumulation_steps,
        # Logging and checkpointing
        callbacks=callbacks,
        log_every_n_steps=report_every_n_steps,
        val_check_interval=val_check_interval,
        # Performance
        benchmark=True,
        deterministic=False,
        # Save model checkpoints (.ckpt files) via Lightning's ModelCheckpoint;
        # this is not activation/gradient checkpointing
        enable_checkpointing=True,
        use_distributed_sampler=False,  # StreamingDataset handles distributed sampling
    )

    # Train the model (resume from checkpoint if provided)
    trainer.fit(model, train_loader, val_loader, ckpt_path=ckpt_path)

    # Print final results
    if trainer.global_rank == 0:
        if val_loader is not None:
            print(
                f"Final validation loss: {trainer.callback_metrics.get('val/loss', 0.0):.4f}"
            )
            print(
                f"Final validation perplexity: {trainer.callback_metrics.get('val/perplexity', 0.0):.4f}"
            )
        print(f"Checkpoints saved to: {checkpoint_dir}")

        return Dir.from_local_sync(output_dir)

    return None
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*

The trainer configuration brings together all the pieces we've discussed:

- **`precision="bf16-mixed"`**: BFloat16 mixed precision training. BF16 has the same dynamic range as FP32 (unlike FP16), so you don't need loss scaling. This is the standard choice for modern GPU training.
- **`gradient_clip_val=1.0`**: Caps gradients at 1.0 to guard against exploding gradients. Under FSDP the cap applies to each gradient element (value-based clipping); under DDP it applies to the global gradient norm.
- **`accumulate_grad_batches`**: Accumulates gradients over multiple forward passes before updating weights. This lets us hit a larger effective batch size than what fits in GPU memory.
- **`val_check_interval`**: How often to run validation. In a long pretraining run a single epoch can take a very long time, so validating once per epoch would be far too infrequent. Instead, validate every `N` training steps.
- **`use_distributed_sampler=False`**: We disable Lightning's built-in distributed sampler because `StreamingDataset` handles data sharding internally. Using both would cause conflicts.
- **`benchmark=True`**: Enables cuDNN autotuning. The first time it sees a given input shape, PyTorch benchmarks the available convolution algorithms and reuses the fastest one for that shape.

The trainer then calls `fit()` with the model, data loaders, and optionally a checkpoint path to resume from.
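
A quick way to see why BF16 avoids loss scaling is to round-trip values through each format. The sketch below uses only Python's `struct` module: FP16 via the standard `'e'` format code, and BF16 approximated by keeping the upper 16 bits of the FP32 encoding. That is truncation, whereas hardware typically rounds to nearest-even, so treat it as an illustration rather than an exact BF16 implementation. A gradient of `1e-8` underflows to zero in FP16 (the failure mode loss scaling exists to prevent) but survives in BF16, and a value like `1e30` does not fit in FP16 at all.

```python
import struct


def to_fp16(x: float) -> float:
    """Round-trip x through IEEE half precision (FP16).

    struct's 'e' format rounds to nearest and raises OverflowError when x is
    too large for FP16."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bf16(x: float) -> float:
    """Approximate BF16 by keeping the top 16 bits of the FP32 encoding
    (1 sign bit, 8 exponent bits, 7 mantissa bits). This truncates instead of
    rounding to nearest-even, so it is an illustration only."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


tiny_grad = 1e-8  # smaller than FP16's smallest subnormal (about 6e-8)
print("FP16:", to_fp16(tiny_grad))  # underflows to 0.0
print("BF16:", to_bf16(tiny_grad))  # stays close to 1e-8

try:
    to_fp16(1e30)
except OverflowError:
    print("1e30 does not fit in FP16")
print("BF16 keeps 1e30 as", to_bf16(1e30))
```

BF16 pays for its FP32-sized exponent with only 7 explicit mantissa bits, so individual values are coarser, but nothing silently flushes to zero or overflows.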

### Tying it together

The pipeline task orchestrates everything:

```
@driver_env.task
async def distributed_llm_pipeline(
    model_size: str,
    dataset_name: str = "Salesforce/wikitext",
    dataset_config: Optional[str] = "wikitext-103-raw-v1",
    max_length: int = 2048,
    max_train_samples: Optional[int] = 10000,
    max_val_samples: Optional[int] = 1000,
    max_steps: int = 100000,
    resume_checkpoint: Optional[Dir] = None,
    checkpoint_upload_steps: int = 1000,
    # Optional overrides (if None, uses model preset defaults)
    batch_size: Optional[int] = None,
    learning_rate: Optional[float] = None,
    use_fsdp: bool = True,
) -> Optional[Dir]:
    # Get model configuration
    model_config = get_model_config(model_size)

    # Use preset values if not overridden
    actual_batch_size = (
        batch_size if batch_size is not None else model_config["batch_size"]
    )
    actual_learning_rate = (
        learning_rate if learning_rate is not None else model_config["learning_rate"]
    )

    # Step 1: Load and prepare streaming dataset
    prepared_dataset = await load_and_prepare_streaming_dataset(
        dataset_name=dataset_name,
        dataset_config=dataset_config,
        max_length=max_length,
        train_split="train",
        val_split="validation",
        max_train_samples=max_train_samples,
        max_val_samples=max_val_samples,
        shard_size_mb=64,  # 64MB shards
    )

    # Step 2: Run distributed training
    if resume_checkpoint is not None:
        print("\nStep 2: Resuming distributed training from checkpoint...")
    else:
        print("\nStep 2: Starting distributed training from scratch...")

    target_global_batch = 256
    world_size = NUM_NODES * DEVICES_PER_NODE
    effective_per_step = world_size * actual_batch_size
    grad_accumulation_steps = max(
        1, math.ceil(target_global_batch / max(1, effective_per_step))
    )

    checkpoint_dir = train_distributed_llm(
        prepared_dataset=prepared_dataset,
        resume_checkpoint=resume_checkpoint,
        vocab_size=VOCAB_SIZE,
        n_positions=N_POSITIONS,
        n_embd=model_config["n_embd"],
        n_layer=model_config["n_layer"],
        n_head=model_config["n_head"],
        batch_size=actual_batch_size,
        num_workers=8,
        max_steps=max_steps,
        learning_rate=actual_learning_rate,
        weight_decay=0.1,
        warmup_steps=500,
        use_fsdp=use_fsdp,
        checkpoint_upload_steps=checkpoint_upload_steps,
        checkpoint_every_n_steps=model_config["checkpoint_every_n_steps"],
        report_every_n_steps=model_config["report_every_n_steps"],
        val_check_interval=model_config["val_check_interval"],
        grad_accumulation_steps=grad_accumulation_steps,
    )

    return checkpoint_dir
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*

The flow is straightforward: load the configuration, prepare the data, and run training. Because the driver awaits `load_and_prepare_streaming_dataset` before calling `train_distributed_llm`, data preparation always finishes before training starts. If data preparation is cached from a previous run, it returns immediately and training starts right away.

The gradient accumulation calculation is worth noting. We want a global batch size of 256 (this affects training dynamics), but each GPU can only fit a small batch. With 8 GPUs and batch size 1 each, we need 32 accumulation steps to hit 256.
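
Because the formula rounds up, the realized global batch equals the target only when the target divides evenly by the per-step batch. Otherwise it overshoots, and a per-step batch that already exceeds the target cannot be shrunk by accumulation at all. The sketch below reuses the pipeline's formula on a few hypothetical cluster shapes:

```python
import math


def grad_accumulation_steps(target_global_batch: int, world_size: int, per_device_batch: int) -> int:
    """Same formula as the pipeline: round up so the target is never undershot."""
    effective_per_step = world_size * per_device_batch
    return max(1, math.ceil(target_global_batch / max(1, effective_per_step)))


# Hypothetical (nodes, GPUs per node, per-device batch) combinations.
for nodes, gpus_per_node, per_device in [(1, 8, 1), (4, 8, 1), (3, 4, 1), (1, 8, 64)]:
    world = nodes * gpus_per_node
    steps = grad_accumulation_steps(256, world, per_device)
    realized = world * per_device * steps
    print(f"{nodes} node(s) x {gpus_per_node} GPUs, batch {per_device}: "
          f"{steps} accumulation steps -> global batch {realized}")
```

For example, 3 nodes × 4 GPUs gives 22 accumulation steps and a realized global batch of 264 rather than 256, while 8 GPUs with a per-device batch of 64 already process 512 sequences per step.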

## Running the pipeline

With everything defined, running is simple:

```
if __name__ == "__main__":
    flyte.init_from_config()

    run = flyte.run(
        distributed_llm_pipeline,
        model_size="30B",
        dataset_name="cerebras/SlimPajama-627B",
        dataset_config=None,
        max_length=2048,
        max_train_samples=5_000_000,
        max_val_samples=50_000,
        max_steps=15_000,
        use_fsdp=True,
        checkpoint_upload_steps=1000,
    )

    print(f"Run URL: {run.url}")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*

This configuration is designed for testing and demonstration. Notice `max_train_samples=5_000_000`: that's 5 million samples from a dataset with 627 billion tokens, a tiny fraction that is enough to verify everything works without burning through compute.
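
To put those numbers in perspective, here is a back-of-the-envelope token budget for this demo run. It assumes that `max_steps` counts optimizer steps (as Lightning's `Trainer` does, so each step consumes one global batch), that gradient accumulation lands exactly on the 256-sequence target, and that every sample fills the full 2,048-token context, which makes the token figure an upper bound:

```python
# Back-of-the-envelope token budget for the demo run (upper bound).
max_steps = 15_000          # from the flyte.run call
global_batch = 256          # assumption: accumulation hits the target exactly
seq_len = 2_048             # max_length; assumes every sample fills the context
max_train_samples = 5_000_000
dataset_tokens = 627e9      # SlimPajama-627B

samples_seen = max_steps * global_batch
tokens_seen = samples_seen * seq_len

print(f"samples seen: {samples_seen:,} ({samples_seen / max_train_samples:.0%} of the sample cap)")
print(f"tokens seen:  <= {tokens_seen / 1e9:.2f}B ({tokens_seen / dataset_tokens:.2%} of the dataset)")
```

Under these assumptions the run stops after about 3.84 million samples, roughly 77% of the 5 million cap, and at most about 7.9 billion tokens, around 1.25% of the corpus.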

For a real pretraining run, remove this limit by setting `max_train_samples=None` or raise it significantly. You would also increase `max_steps` to match your compute budget, likely scale out by setting `NUM_NODES=4` or higher, and allocate more resources. The rest of the pipeline remains unchanged.

To launch the run, create a Flyte config that points at your endpoint, then execute the script:

```bash
flyte create config --endpoint <FLYTE_OR_UNION_ENDPOINT> --project <PROJECT_NAME> --domain <DOMAIN_NAME> --builder remote
uv run train.py
```

When you run this, Flyte:

1. **Builds the container** (cached after first run)
2. **Schedules data prep** on CPU nodes
3. **Waits for data prep** (or skips if cached)
4. **Provisions H200 nodes** and launches distributed training
5. **Streams logs and metrics** to the UI in real-time

Open the Flyte UI to observe the workflow execution. The data preparation task completes first, followed by the training task spinning up. As training begins, the Flyte Reports dashboard starts plotting loss curves. If anything goes wrong, the logs are immediately available in the UI.

![Training Log](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/tutorials/distributed-llm-pretraining/logs.png)

If training fails due to an out-of-memory error, a GPU driver error, or a hardware issue, check the logs, fix the problem, and restart the run with `resume_checkpoint` pointing to the most recent checkpoint. Training resumes from where it left off. Flyte tracks the full execution history, so it is easy to see exactly what happened.
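
The tutorial doesn't show how checkpoint files are named. Assuming Lightning's default `ModelCheckpoint` naming (`epoch={epoch}-step={step}`, rendered as, for example, `epoch=0-step=3000.ckpt`), a small hypothetical helper like `latest_checkpoint` below can pick the most recent entry from a listing of the checkpoint directory before you point `resume_checkpoint` at it:

```python
import os
import re
from typing import Iterable, Optional

# Hypothetical helper. Assumes Lightning's default checkpoint names,
# e.g. "epoch=0-step=3000.ckpt" (a file, or a directory for sharded strategies).
_STEP_RE = re.compile(r"step=(\d+)")


def latest_checkpoint(paths: Iterable[str]) -> Optional[str]:
    """Return the path whose name has the highest step number, or None."""
    best_path, best_step = None, -1
    for path in paths:
        name = os.path.basename(path.rstrip("/"))
        match = _STEP_RE.search(name)
        if match and int(match.group(1)) > best_step:
            best_path, best_step = path, int(match.group(1))
    return best_path


listing = [
    "ckpts/epoch=0-step=1000.ckpt",
    "ckpts/epoch=0-step=3000.ckpt/",
    "ckpts/epoch=0-step=2000.ckpt",
    "ckpts/last.ckpt",
]
print(latest_checkpoint(listing))
```

Comparing parsed step numbers, rather than sorting the names as strings, avoids ranking `step=900` above `step=1000`.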

## Going further

If you've run through this tutorial, here's where to go next depending on what you're trying to do:

**You want to train on your own data.** The data prep task accepts any HuggingFace dataset with a `text` column. If your data isn't on HuggingFace, you can modify `load_and_prepare_streaming_dataset` to read from S3, local files, or any other source. The key is getting your data into MDS format. Once it's there, the streaming and sharding just works. For production training, look at SlimPajama, [RedPajama](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T), or [The Pile](https://huggingface.co/datasets/EleutherAI/pile) as starting points.
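
As a starting point for a non-HuggingFace source, a hypothetical reader like `iter_text_records` can normalize local files into records with a single `text` field, matching the column the data prep task expects. The file layout it handles (`.jsonl` files with a `text` key per line, or one document per `.txt` file) is an assumption, and wiring the records into `load_and_prepare_streaming_dataset` and the MDS conversion is not shown:

```python
import json
from pathlib import Path
from typing import Iterator


def iter_text_records(root: str) -> Iterator[dict]:
    """Hypothetical reader: yield {"text": ...} records from files under root.

    .jsonl files: one JSON object per line, text taken from its "text" key.
    .txt files: the whole file is treated as one document.
    Blank lines and empty documents are skipped."""
    for path in sorted(Path(root).rglob("*")):
        if path.suffix == ".jsonl":
            with path.open(encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if line:
                        text = json.loads(line).get("text", "")
                        if text:
                            yield {"text": text}
        elif path.suffix == ".txt":
            text = path.read_text(encoding="utf-8").strip()
            if text:
                yield {"text": text}
```

Because it is a generator, records stream one at a time, so the reader doesn't need to hold the corpus in memory while you shard it.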

**You want to scale to more GPUs.** Bump `NUM_NODES` and Flyte handles the rest. The main thing to watch is the effective batch size. The driver recomputes gradient accumulation steps from the world size, so adding GPUs automatically lowers the accumulation count and keeps the global batch at the 256 target (or slightly above it when 256 doesn't divide evenly by the per-step batch). To experiment with larger batches, raise `target_global_batch` in `distributed_llm_pipeline`.

**Your training keeps failing.** Add `retries=3` to your task decorator for automatic retry on transient failures. This handles spot instance preemption, temporary network issues, and the occasional GPU that decides to stop working. Combined with checkpointing, you get fault-tolerant training that can survive most infrastructure hiccups. For persistent failures, start with the Flyte UI logs, which capture stdout/stderr from all ranks.

**You want better visibility into what's happening.** We're actively working on surfacing GPU driver logs (xid/sxid errors), memory utilization breakdowns, and NCCL communication metrics directly in the Flyte UI. If you're hitting issues that the current logs don't explain, reach out. Your feedback helps us prioritize what observability features to build next!

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/computer-vision ===

# Computer vision

Tutorials for image and vision-language model workloads.

### **Computer vision > Fine-tuning a VLM**

Adapt Qwen2.5-VL to occluded image classification by training a 10K-parameter adapter with multi-node DeepSpeed, automatic recovery, and live training dashboards.

### **Computer vision > RT-DETR object detection**

Fine-tune RT-DETRv2 on a COCO dataset with live training charts, mAP evaluation, and bounding-box demos.

### **Computer vision > Multimodal retrieval evaluation**

Benchmark ColPali, SigLIP, and OCR+BM25 visual document retrieval on ViDoRe with warm GPU containers, dynamic batching, and an interactive report.

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/computer-vision/qwen-vl-finetuning ===

# Fine-tuning a VLM

Large vision-language models like Qwen2.5-VL are remarkably capable out of the box. But adapting one to a specialized task raises an immediate question: do you really need to update 3 billion parameters?

Usually, no. The **frozen backbone pattern** is a practical alternative: keep all pretrained weights frozen and train only a small, task-specific adapter inserted before the vision encoder. The adapter learns to transform its input in a way that makes the frozen model perform well on your task without touching the underlying billions of parameters. The result is faster training, lower memory pressure, and a much smaller set of weights to store and version.

This tutorial makes that pattern concrete. We take a partially-occluded image classification task — CIFAR-10 images with random black rectangles covering 22–45% of the frame — and train a tiny Conv2d adapter to "see through" the occlusion before the frozen VLM processes it. The adapter has approximately **10,500 trainable parameters**. The backbone has 3 billion.

The machine learning is interesting, but the real focus here is on shipping a production-grade training pipeline:

- **Multi-node distributed training** across 2 nodes × 4 GPUs using PyTorch Elastic and DeepSpeed Stage 2
- **Automatic fault tolerance**: checkpoints upload to object storage after every validation epoch; if training fails, the pipeline returns the last known-good checkpoint instead of crashing
- **Live observability**: a streaming HTML dashboard in the Flyte UI updates in real-time as training runs, no separate monitoring infrastructure required
- **Cached data preparation**: dataset processing runs once and is reused across all reruns
- **Clean task isolation**: each stage runs with exactly the resources it needs, nothing more

> [!NOTE]
> Full code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/qwen_vl_frozen_backbone_finetuning).

## Overview

The pipeline has four tasks with clearly defined responsibilities:

1. **Dataset preparation** (`prepare_occlusion_dataset`): Downloads CIFAR-10, applies random occlusions, and writes image manifests as streaming JSONL files to object storage. Runs on CPU and is cached, so it only runs once regardless of how many times you rerun the pipeline with the same config.
2. **Multi-node training** (`train_qwen_adapter_multinode`): Runs PyTorch Lightning with DeepSpeed Stage 2 across 2 nodes × 4 L40s GPUs. Only the adapter trains; the 3B backbone stays frozen.
3. **Evaluation** (`evaluate_qwen_adapter`): Loads the saved adapter, runs inference on validation examples, and produces a predictions report. Runs on a single GPU.
4. **Driver** (`qwen_vl_multinode_deepspeed`): The pipeline entry point. Orchestrates the three tasks above, manages WandB initialization, handles recovery from training failures, and produces a final HTML report in the Flyte UI.

Why this separation? It mirrors how production pipelines should be structured. Data prep is cheap and deterministic so we cache it. Training is expensive and failure-prone so we isolate it with fault tolerance. Evaluation needs different hardware than training. The driver is pure coordination, so it gets minimal resources.

## Implementation

### Setting up the environment

Different tasks need different compute. Flyte's `TaskEnvironment` is how you declare exactly what each task needs.

First, define the container images. Training needs a full CUDA stack with ML libraries, driver compatibility, and DeepSpeed's build tools:

```
gpu_image = (
    flyte.Image.from_base("nvidia/cuda:12.8.0-cudnn-devel-ubuntu22.04")
    .clone(name="qwen_vl_multinode_deepspeed", python_version=(3, 13), extendable=True)
    .with_apt_packages("build-essential")
    .with_pip_packages(
        "torch==2.9.1",
        "torchvision==0.24.1",
        "lightning==2.6.1",
        "transformers==4.57.3",
        "deepspeed==0.18.8",
        "datasets==4.4.1",
        "pillow==11.3.0",
        "flyteplugins-pytorch>=2.0.11",
        "flyteplugins-jsonl>=2.0.11",
        "flyteplugins-wandb>=2.0.11",
    )
)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/qwen_vl_frozen_backbone_finetuning/config.py*

`from_base` starts from the official NVIDIA CUDA image, giving you NCCL, cuDNN, and the right driver headers out of the box. `with_apt_packages("build-essential")` is required because DeepSpeed compiles CUDA kernels on first use; without build tools, it silently falls back to slower CPU implementations. The non-GPU image for data preparation and orchestration is much lighter:

```
non_gpu_image = flyte.Image.from_debian_base(
    name="qwen_vl_multinode_deepspeed_non_gpu"
).with_pip_packages(
    "flyteplugins-pytorch>=2.0.11",
    "flyteplugins-jsonl>=2.0.11",
    "flyteplugins-wandb>=2.0.11",
    "lightning==2.6.1",
    "datasets==4.4.1",
    "pillow==11.3.0",
    "torchvision==0.24.1",
)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/qwen_vl_frozen_backbone_finetuning/config.py*

With images defined, each task gets its own resource declaration:

```
dataset_env = flyte.TaskEnvironment(
    name="qwen_vl_dataset_prep",
    image=non_gpu_image,
    resources=flyte.Resources(cpu=5, memory="15Gi"),
    cache="auto",
)

training_env = flyte.TaskEnvironment(
    name="qwen_vl_multinode_training",
    image=gpu_image,
    resources=flyte.Resources(
        cpu=42,
        memory="256Gi",
        gpu=f"L40s:{DEVICES_PER_NODE}",
        shm="16Gi",
    ),
    plugin_config=Elastic(nnodes=NUM_NODES, nproc_per_node=DEVICES_PER_NODE),
    secrets=[
        flyte.Secret(key="wandb_api_key", as_env_var="WANDB_API_KEY")
    ],  # TODO: update with your own secret key
    env_vars={
        "TORCH_DISTRIBUTED_DEBUG": "INFO",
        "NCCL_DEBUG": "WARN",
        "TOKENIZERS_PARALLELISM": "false",
        "CUDA_HOME": "/usr/local/cuda",
        "DS_SKIP_CUDA_CHECK": "1",
    },
)

evaluation_env = flyte.TaskEnvironment(
    name="qwen_vl_adapter_eval",
    image=gpu_image,
    resources=flyte.Resources(cpu=16, memory="64Gi", gpu="L40s:1"),
    cache="auto",
)

driver_env = flyte.TaskEnvironment(
    name="qwen_vl_multinode_driver",
    image=non_gpu_image,
    resources=flyte.Resources(cpu=2, memory="4Gi"),
    depends_on=[dataset_env, training_env, evaluation_env],
)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/qwen_vl_frozen_backbone_finetuning/config.py*

A few things worth noting here:

- **`Elastic(nnodes=2, nproc_per_node=4)`**: Flyte's integration with PyTorch's elastic launch. It handles process spawning (one process per GPU), rank assignment, and distributed environment setup — master address, world size, rendezvous — without any shell scripting or manual `torchrun` invocations.
- **`shm="16Gi"`**: Shared memory is required for NCCL inter-GPU communication on the same node. Without it, you'll see cryptic errors from the communication library when training starts.
- **`cache="auto"`**: The dataset preparation task is cached by input hash. Running the pipeline twice with the same hyperparameters skips it entirely on the second run.
- **`depends_on`**: The driver environment lists the environments whose tasks it calls, so Flyte deploys them (including building their images) together with the driver. This ensures the worker containers are ready before the driver begins orchestrating.
- **`secrets`**: The WandB API key is injected from Flyte's secret store as an environment variable. No credentials in code.

All training hyperparameters flow through a single typed dataclass:

```
@dataclass
class Config:
    model_name: str = DEFAULT_MODEL_NAME
    image_size: int = IMAGE_SIZE
    max_train_samples: int = 1024
    max_val_samples: int = 256
    epochs: int = 8
    per_device_batch_size: int = 1
    target_global_batch_size: int = 16
    learning_rate: float = 2e-4
    weight_decay: float = 1e-2
    reconstruction_loss_weight: float = 0.35
    report_every_n_steps: int = 10
    num_workers: int = 4
    max_length: int = 512
    eval_examples: int = 16
    train_occlusion_min: float = 0.22
    train_occlusion_max: float = 0.42
    eval_occlusion_min: float = 0.28
    eval_occlusion_max: float = 0.45
    seed: int = 7

    def to_dict(self) -> dict:
        return asdict(self)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/qwen_vl_frozen_backbone_finetuning/config.py*

Using a dataclass rather than scattered constants or argparse arguments means the full config is serializable, can be stored in artifact metadata alongside the model checkpoint, and flows cleanly as a typed input between tasks. The `to_dict()` method serializes it for WandB logging.
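
Here is what that serializability buys you in practice. The sketch uses a trimmed, hypothetical `ConfigSketch` with a few of `Config`'s fields and round-trips it through JSON, the kind of payload you might store as checkpoint metadata. `Config` itself should behave the same way, since all of its fields are plain `int`, `float`, and `str` values:

```python
import json
from dataclasses import asdict, dataclass


@dataclass
class ConfigSketch:
    # A trimmed, hypothetical subset of the tutorial's Config fields.
    max_train_samples: int = 1024
    per_device_batch_size: int = 1
    target_global_batch_size: int = 16
    learning_rate: float = 2e-4
    seed: int = 7

    def to_dict(self) -> dict:
        return asdict(self)


cfg = ConfigSketch(learning_rate=1e-4)
payload = json.dumps(cfg.to_dict(), sort_keys=True)  # e.g. stored next to a checkpoint
restored = ConfigSketch(**json.loads(payload))        # rebuild the exact run config
print(payload)
```

Dataclass equality compares fields, so `restored == cfg` holds after the round trip, which makes it easy to check that a checkpoint was produced by the config you expect.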

### Preparing the dataset

The dataset task handles everything: downloading CIFAR-10, generating occlusions, and writing the manifests.

```
@dataset_env.task
async def prepare_occlusion_dataset(config: Config) -> DatasetArtifacts:
    from PIL import Image
    from torchvision.datasets import CIFAR10
    from flyte.io import Dir
    from flyteplugins.jsonl import JsonlFile
    import random

    rng = random.Random(config.seed)

    images_dir = Path("/tmp/qwen_vl_occlusion_images")
    train_images_dir = images_dir / "train" / "images"
    val_images_dir = images_dir / "validation" / "images"
    train_images_dir.mkdir(parents=True, exist_ok=True)
    val_images_dir.mkdir(parents=True, exist_ok=True)

    prompt = (
        "The image may be partially occluded. "
        "Answer with exactly one CIFAR-10 class label: "
        + ", ".join(CLASS_NAMES)
        + ". What is the main object?"
    )

    async def export_split(
        dataset,
        split_name: str,
        limit: int,
        local_image_dir: Path,
        occ_min: float,
        occ_max: float,
    ):
        out = JsonlFile.new_remote(f"{split_name}_manifest.jsonl")
        async with out.writer() as writer:
            for idx in range(limit):
                pil_image, label_idx = dataset[idx]
                resized = pil_image.resize(
                    (config.image_size, config.image_size),
                    resample=Image.Resampling.BICUBIC,
                )
                rel_path = f"{split_name}/images/{split_name}-{idx:05d}.png"
                resized.save(local_image_dir / f"{split_name}-{idx:05d}.png")
                occlusion = build_occlusion_box(
                    width=config.image_size,
                    height=config.image_size,
                    rng=rng,
                    min_fraction=occ_min,
                    max_fraction=occ_max,
                )
                await writer.write(
                    {
                        "image_path": rel_path,
                        "label": CLASS_NAMES[label_idx],
                        "label_index": int(label_idx),
                        "prompt": prompt,
                        "occlusion": occlusion,
                    }
                )
        return out

    train_dataset = CIFAR10(root="/tmp/cifar10", train=True, download=True)
    val_dataset = CIFAR10(root="/tmp/cifar10", train=False, download=True)

    train_manifest = await export_split(
        train_dataset,
        "train",
        config.max_train_samples,
        train_images_dir,
        config.train_occlusion_min,
        config.train_occlusion_max,
    )
    val_manifest = await export_split(
        val_dataset,
        "validation",
        config.max_val_samples,
        val_images_dir,
        config.eval_occlusion_min,
        config.eval_occlusion_max,
    )

    return DatasetArtifacts(
        train_manifest=train_manifest,
        val_manifest=val_manifest,
        images=await Dir.from_local(str(images_dir)),
    )
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/qwen_vl_frozen_backbone_finetuning/data.py*

Each image gets a randomly-placed black rectangle. The occlusion covers 22–42% of the image area during training and 28–45% during evaluation. The occlusion is deliberately harder at eval time to test how robust the adapter is. The bounding box coordinates are written into each manifest record alongside the image path and ground-truth label, so the training task can reconstruct the binary occlusion mask as the adapter's fourth input channel.
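
`build_occlusion_box` itself isn't shown in this tutorial. The sketch below is a hypothetical stand-in, not the repository's implementation. It samples an area fraction uniformly from the configured range, picks an aspect ratio (the 0.5–2.0 range is an assumption), places the box uniformly at random, and shows how the stored coordinates turn back into a binary mask. The `x0`/`y0`/`x1`/`y1` keys are likewise assumed:

```python
import math
import random


def build_occlusion_box_sketch(width, height, rng, min_fraction, max_fraction):
    """Hypothetical stand-in for build_occlusion_box: a random axis-aligned box
    covering roughly min_fraction..max_fraction of the image area."""
    area_fraction = rng.uniform(min_fraction, max_fraction)
    aspect = rng.uniform(0.5, 2.0)  # assumed width/height ratio range
    target_area = area_fraction * width * height
    box_w = min(width, max(1, round(math.sqrt(target_area * aspect))))
    box_h = min(height, max(1, round(target_area / box_w)))
    x0 = rng.randint(0, width - box_w)
    y0 = rng.randint(0, height - box_h)
    return {"x0": x0, "y0": y0, "x1": x0 + box_w, "y1": y0 + box_h}


def occlusion_mask(box, width, height):
    """Rebuild the binary mask (1.0 = occluded) from stored box coordinates."""
    return [
        [1.0 if box["x0"] <= x < box["x1"] and box["y0"] <= y < box["y1"] else 0.0
         for x in range(width)]
        for y in range(height)
    ]


rng = random.Random(7)  # seeded, like Config.seed, so manifests are reproducible
box = build_occlusion_box_sketch(64, 64, rng, 0.22, 0.42)
mask = occlusion_mask(box, 64, 64)
print(box, sum(map(sum, mask)) / (64 * 64))
```

Storing only the four coordinates keeps each manifest record small; the mask is cheap to rebuild at training time.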

Two Flyte primitives handle data persistence without any manual storage management:

- **`JsonlFile.new_remote()`** opens a streaming writer that writes directly to remote object storage. The training task reads records back via `jf.iter_records_sync()`, so there are no local file paths or S3 credentials to manage.
- **`Dir.from_local()`** uploads the local images directory to object storage and returns a typed handle. The training task downloads it to a local path via `Dir.download_sync()`.

Because `cache="auto"` is set on this task, dataset preparation runs once. Subsequent reruns with the same config skip it entirely.

### The adapter

Here's the entire trainable component of the model, with roughly 10,500 parameters:

```
class ResidualOcclusionAdapter(nn.Module):
    def __init__(self, hidden_channels: int = 32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(4, hidden_channels, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(hidden_channels, 3, kernel_size=1),
            nn.Tanh(),
        )
        self.gate = nn.Parameter(torch.tensor(0.10))

    def forward(
        self, pixel_values: torch.Tensor, occlusion_mask: torch.Tensor
    ) -> torch.Tensor:
        if pixel_values.ndim != 4:
            raise ValueError(
                "ResidualOcclusionAdapter expects dense image tensors with shape "
                f"(B, C, H, W), but received {tuple(pixel_values.shape)}."
            )
        if occlusion_mask.ndim == 3:
            occlusion_mask = occlusion_mask.unsqueeze(1)
        adapter_input = torch.cat(
            [pixel_values, occlusion_mask.to(pixel_values.dtype)],
            dim=1,
        )
        residual = self.net(adapter_input)
        return pixel_values + torch.tanh(self.gate) * residual
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/qwen_vl_frozen_backbone_finetuning/model.py*

The adapter takes the occluded image (3 channels) concatenated with the binary occlusion mask (1 channel) as a 4-channel input. It predicts a residual correction through a small convolutional network, then adds that correction back to the original pixels. The learnable `gate` scalar, initialized to `0.10`, controls how strongly the adapter modifies the image: the residual is scaled by `tanh(gate)`, which is about 0.1 at initialization and always stays between -1 and 1. The adapter therefore starts as a near-identity transformation, and the gate can grow during training as the correction proves useful.
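
As a check on the ~10,500 figure: each `Conv2d` contributes `out × in × k × k` weights plus `out` biases, `GELU` and `Tanh` have no parameters, and the gate adds one scalar. The plain-Python arithmetic below should match what `sum(p.numel() for p in ResidualOcclusionAdapter().parameters())` reports for the default `hidden_channels=32`:

```python
def conv2d_params(in_ch: int, out_ch: int, kernel_size: int) -> int:
    """Parameters of nn.Conv2d with bias=True: out*in*k*k weights + out biases."""
    return out_ch * in_ch * kernel_size * kernel_size + out_ch


hidden = 32  # ResidualOcclusionAdapter's default hidden_channels
counts = {
    "Conv2d(4 -> 32, 3x3)": conv2d_params(4, hidden, 3),
    "Conv2d(32 -> 32, 3x3)": conv2d_params(hidden, hidden, 3),
    "Conv2d(32 -> 3, 1x1)": conv2d_params(hidden, 3, 1),
    "gate": 1,  # the learnable scalar; GELU and Tanh have no parameters
}
total = sum(counts.values())
for name, n in counts.items():
    print(f"{name:>22}: {n:>6,}")
print(f"{'total':>22}: {total:>6,}")  # 10,532
print(f"share of a 3B backbone: {total / 3e9:.6%}")
```

The middle 3×3 convolution holds almost 90% of the adapter's parameters, so `hidden_channels` is the knob that controls the adapter's size.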

The adapter is plugged into Qwen2.5-VL via a Lightning module:

```
class QwenVLAdapterModule(L.LightningModule):
    def __init__(
        self,
        model_name: str,
        learning_rate: float,
        weight_decay: float,
        reconstruction_loss_weight: float,
    ):
        super().__init__()
        from transformers import Qwen2_5_VLForConditionalGeneration

        self.save_hyperparameters()
        self.adapter = ResidualOcclusionAdapter()

        self.backbone = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            attn_implementation="sdpa",
        )
        self.backbone.requires_grad_(False)
        self.backbone.gradient_checkpointing_enable()

        # DeepSpeed checkpoints only persist the trainable adapter weights when
        # `exclude_frozen_parameters=True`. On resume we rebuild the frozen
        # backbone from Hugging Face and load the checkpoint non-strictly.
        self.strict_loading = False

        self.total_params, self.trainable_params = count_parameters(self)
        self.example_input_array = None
        self.vision_patch_size = int(self.backbone.config.vision_config.patch_size)
        self.temporal_patch_size = int(
            getattr(self.backbone.config.vision_config, "temporal_patch_size", 1)
        )
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/qwen_vl_frozen_backbone_finetuning/model.py*

The key line is `self.backbone.requires_grad_(False)`. This freezes all 3 billion backbone parameters, so only the adapter's ~10,500 weights receive gradients. The backbone still sits on the gradient path, though: the adapter runs before the vision encoder, so the loss must be backpropagated through the entire frozen model to reach it. `gradient_checkpointing_enable()` trades compute for memory here: instead of keeping all of the backbone's intermediate activations in GPU memory for the backward pass, it stores only a subset and recomputes the rest on the fly. This matters when a 3B-parameter model already occupies a large share of each GPU's memory.

`strict_loading = False` handles an important DeepSpeed checkpoint detail. When `exclude_frozen_parameters=True` is set on the strategy, DeepSpeed only saves the adapter weights in checkpoints, not the 3B frozen backbone. On resume, the checkpoint won't contain backbone weights, so loading must be non-strict. The `on_load_checkpoint` hook fills in the missing backbone weights from the freshly-loaded HuggingFace model, combining the best of both worlds: small checkpoints and a fully initialized model.

The training loss combines two objectives:

```
    def _forward_losses(
        self, batch: dict[str, torch.Tensor]
    ) -> dict[str, torch.Tensor]:
        backbone_dtype = next(self.backbone.parameters()).dtype
        if batch["pixel_values"].ndim == 2:
            if "image_grid_thw" not in batch:
                raise ValueError(
                    "Packed Qwen pixel values require `image_grid_thw` to reconstruct "
                    "dense images for the Conv2d adapter."
                )
            grid_thw = batch["image_grid_thw"]
            dense_pixels = packed_pixels_to_dense_images(
                batch["pixel_values"].to(dtype=backbone_dtype),
                grid_thw,
                patch_size=self.vision_patch_size,
                temporal_patch_size=self.temporal_patch_size,
            )
            clean_pixels = packed_pixels_to_dense_images(
                batch["clean_pixel_values"].to(dtype=backbone_dtype),
                grid_thw,
                patch_size=self.vision_patch_size,
                temporal_patch_size=self.temporal_patch_size,
            )
            adapted_dense = self.adapter(dense_pixels, batch["occlusion_mask"])
            adapted_pixels = dense_images_to_packed_pixels(
                adapted_dense,
                grid_thw,
                patch_size=self.vision_patch_size,
                temporal_patch_size=self.temporal_patch_size,
            )
        else:
            clean_pixels = batch["clean_pixel_values"].to(dtype=backbone_dtype)
            adapted_dense = self.adapter(
                batch["pixel_values"].to(dtype=backbone_dtype),
                batch["occlusion_mask"],
            )
            adapted_pixels = adapted_dense

        forward_kwargs = {
            "input_ids": batch["input_ids"],
            "attention_mask": batch["attention_mask"],
            "pixel_values": adapted_pixels,
            "labels": batch["labels"],
        }
        if "image_grid_thw" in batch:
            forward_kwargs["image_grid_thw"] = batch["image_grid_thw"]
        outputs = self.backbone(**forward_kwargs)

        clean_pixels = clean_pixels.to(
            device=adapted_pixels.device, dtype=backbone_dtype
        )
        occlusion_mask = batch["occlusion_mask"].to(
            device=adapted_pixels.device,
            dtype=backbone_dtype,
        )
        if occlusion_mask.ndim == 3:
            occlusion_mask = occlusion_mask.unsqueeze(1)
        if occlusion_mask.shape[-2:] != adapted_dense.shape[-2:]:
            occlusion_mask = F.interpolate(
                occlusion_mask,
                size=adapted_dense.shape[-2:],
                mode="nearest",
            )

        reconstruction_error = (adapted_dense - clean_pixels).abs() * occlusion_mask
        mask_denominator = (occlusion_mask.sum() * adapted_dense.shape[1]).clamp_min(
            1.0
        )

        reconstruction_loss = reconstruction_error.sum() / mask_denominator
        total_loss = (
            outputs.loss + self.hparams.reconstruction_loss_weight * reconstruction_loss
        )

        return {
            "total_loss": total_loss,
            "lm_loss": outputs.loss,
            "reconstruction_loss": reconstruction_loss,
        }
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/qwen_vl_frozen_backbone_finetuning/model.py*

The **language modeling loss** (cross-entropy on the predicted class label tokens) drives the model to produce correct answers. The **reconstruction loss** (mean absolute error between the adapter's output and the clean image, computed only in the occluded region) pushes the adapter to actually restore the missing pixels rather than finding a representation shortcut. Without it, the adapter could overfit the frozen backbone's quirks and produce correct tokens while generating noise in the masked region. The `reconstruction_loss_weight` (default `0.35`) balances these two objectives.
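
The denominator in `_forward_losses` multiplies the mask sum by the channel count because the single-channel mask is broadcast over all three color channels, so each occluded pixel contributes three absolute errors. The sketch below reproduces that reduction for one image in plain Python (batch dimension dropped, nested lists instead of tensors) together with the weighted total; it illustrates the arithmetic and is not the tutorial's code:

```python
def masked_l1(adapted, clean, mask):
    """adapted/clean: [C][H][W] nested lists; mask: [H][W] with 1.0 = occluded.

    Sums |adapted - clean| * mask over every channel and pixel, then divides
    by (occluded pixels * channels), clamped to at least 1.0 as in the
    tutorial's reduction."""
    channels = len(adapted)
    total = 0.0
    for c in range(channels):
        for y, row in enumerate(mask):
            for x, m in enumerate(row):
                total += abs(adapted[c][y][x] - clean[c][y][x]) * m
    denominator = max(sum(map(sum, mask)) * channels, 1.0)
    return total / denominator


def combined_loss(lm_loss, reconstruction_loss, weight=0.35):
    """lm_loss + reconstruction_loss_weight * reconstruction_loss."""
    return lm_loss + weight * reconstruction_loss


# 3-channel 2x2 image, only the top-left pixel occluded.
clean = [[[0.0, 0.0], [0.0, 0.0]] for _ in range(3)]
adapted = [[[0.5, 0.9], [0.9, 0.9]] for _ in range(3)]
mask = [[1.0, 0.0], [0.0, 0.0]]
recon = masked_l1(adapted, clean, mask)
print(recon, combined_loss(2.0, recon))
```

The errors outside the occluded pixel (0.9 here) are ignored, so the result is exactly the mean per-value error inside the occlusion. An image with an empty mask contributes zero rather than dividing by zero.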

Because Qwen2.5-VL's preprocessor packs image patches into a flat `(num_patches, patch_dim)` tensor, the adapter must unpack this into a spatial `(B, C, H, W)` tensor, apply the convolutions, then repack. The `packed_pixels_to_dense_images` and `dense_images_to_packed_pixels` utilities in `model.py` handle this format conversion transparently.
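
The core idea of that round trip can be shown with a simplified patchify/unpatchify pair. This sketch flattens non-overlapping `patch × patch` tiles in row-major grid order and puts them back. The real utilities in `model.py` also handle Qwen's temporal patch dimension and its own patch ordering, which this ignores, so it is an illustration of the packed-to-dense conversion, not a drop-in replacement:

```python
def patchify(image, patch):
    """image: [C][H][W] nested lists -> list of flat patches (the packed
    (num_patches, patch_dim) layout), row-major over the patch grid, each patch
    flattened in (channel, row, column) order."""
    channels, height, width = len(image), len(image[0]), len(image[0][0])
    assert height % patch == 0 and width % patch == 0
    patches = []
    for gy in range(height // patch):
        for gx in range(width // patch):
            patches.append([
                image[c][gy * patch + py][gx * patch + px]
                for c in range(channels)
                for py in range(patch)
                for px in range(patch)
            ])
    return patches


def unpatchify(patches, channels, height, width, patch):
    """Inverse of patchify: rebuild the dense [C][H][W] image."""
    image = [[[0.0] * width for _ in range(height)] for _ in range(channels)]
    grid_w = width // patch
    for idx, flat in enumerate(patches):
        gy, gx = divmod(idx, grid_w)
        i = 0
        for c in range(channels):
            for py in range(patch):
                for px in range(patch):
                    image[c][gy * patch + py][gx * patch + px] = flat[i]
                    i += 1
    return image


C, H, W, P = 3, 4, 6, 2
image = [[[c * 100 + y * 10 + x for x in range(W)] for y in range(H)] for c in range(C)]
packed = patchify(image, P)  # 6 patches of length 12
print(len(packed), len(packed[0]), unpatchify(packed, C, H, W, P) == image)
```

The spatial size has to be known to unpack, which is why the adapter needs the per-image grid dimensions from `image_grid_thw` before it can run its convolutions.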

### Multi-node training with DeepSpeed

The training task is a standard PyTorch Lightning training loop with distributed infrastructure handled by Flyte and DeepSpeed:

```
@wandb_init
@training_env.task(report=True)
def train_qwen_adapter_multinode(
    train_manifest: JsonlFile,
    val_manifest: JsonlFile,
    images_dir: Dir,
    config: Config,
    resume_from: Optional[Dir] = None,
    recovery_uri: Optional[str] = None,
) -> Optional[Dir]:
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/qwen_vl_frozen_backbone_finetuning/tasks.py*

The `@wandb_init` decorator integrates with the `wandb_config` context created in the driver task. It retrieves the initialized WandB run and attaches a `WandbLogger` to the trainer. The `report=True` flag on the task decorator enables Flyte Reports for live dashboard streaming from this task.

![Live Training](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/tutorials/qwen-vl-finetuning/live_training_graph.png)
![Live Training Contd](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/tutorials/qwen-vl-finetuning/losses.png)

DeepSpeed Stage 2 shards optimizer states and gradients across GPUs, which reduces per-GPU memory for models with many trainable parameters. Here only the adapter trains, so its optimizer state is tiny; the critical configuration flag is `exclude_frozen_parameters=True`:

```
    strategy = DeepSpeedStrategy(
        stage=2,
        offload_optimizer=False,
        offload_parameters=False,
        process_group_backend="nccl",
        exclude_frozen_parameters=True,
    )
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/qwen_vl_frozen_backbone_finetuning/tasks.py*

Without `exclude_frozen_parameters=True`, every DeepSpeed checkpoint would also contain the frozen backbone weights, producing enormous checkpoint files, slow saves, and slow uploads to the recovery URI. With it, checkpoints contain only the trainable adapter weights and their optimizer state. The backbone is never checkpointed; each worker loads it independently from HuggingFace.

Gradient accumulation is computed automatically to hit the target global batch size regardless of how many GPUs are actually running:

```
    world_size = NUM_NODES * DEVICES_PER_NODE
    per_step_batch = world_size * config.per_device_batch_size
    grad_accum_steps = max(
        1,
        math.ceil(config.target_global_batch_size / max(1, per_step_batch)),
    )
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/qwen_vl_frozen_backbone_finetuning/tasks.py*

With 2 nodes × 4 GPUs × per-device batch size 1, the effective per-step batch is 8. To reach the default target of 16, the trainer accumulates over 2 steps. Change `NUM_NODES` or `per_device_batch_size` and the calculation adjusts automatically.
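To see how the accumulation count responds to different cluster shapes, the formula above can be lifted into a standalone function. The wrapper names `grad_accum_steps` and `effective_global_batch` are hypothetical; the arithmetic is the same as in the training task.

```python
import math


def grad_accum_steps(num_nodes: int, devices_per_node: int,
                     per_device_batch_size: int, target_global_batch_size: int) -> int:
    """Same formula as the training task: ceil(target / per-step batch), at least 1."""
    world_size = num_nodes * devices_per_node
    per_step_batch = world_size * per_device_batch_size
    return max(1, math.ceil(target_global_batch_size / max(1, per_step_batch)))


def effective_global_batch(num_nodes: int, devices_per_node: int,
                           per_device_batch_size: int, target_global_batch_size: int) -> int:
    """Samples contributing to each optimizer step once accumulation is applied."""
    steps = grad_accum_steps(num_nodes, devices_per_node,
                             per_device_batch_size, target_global_batch_size)
    return num_nodes * devices_per_node * per_device_batch_size * steps


for nodes in (1, 2, 4, 8):
    print(nodes, grad_accum_steps(nodes, 4, 1, 16), effective_global_batch(nodes, 4, 1, 16))
```

Because of the `ceil`, the effective batch can overshoot the target. A target of 20 on 2 × 4 GPUs accumulates 3 steps for an effective batch of 24. Similarly, once the per-step batch alone exceeds the target (8 nodes × 4 GPUs = 32), accumulation cannot drop below 1.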

The trainer brings everything together:

```
    trainer = L.Trainer(
        accelerator="gpu",
        devices=DEVICES_PER_NODE,
        num_nodes=NUM_NODES,
        strategy=strategy,
        logger=wandb_logger,
        precision="bf16-mixed",
        max_epochs=config.epochs,
        accumulate_grad_batches=grad_accum_steps,
        callbacks=[
            checkpoint_callback,
            metrics_callback,
            recovery_callback,
            live_report_callback,
        ],
        gradient_clip_val=1.0,
        benchmark=True,
        log_every_n_steps=1,
    )
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/qwen_vl_frozen_backbone_finetuning/tasks.py*

`precision="bf16-mixed"` uses BFloat16, which has the same 8-bit exponent as FP32 and therefore the same dynamic range (unlike FP16), so you don't need loss scaling. This is the standard choice for modern VLM training. `benchmark=True` enables cuDNN autotuning: the first time cuDNN sees a given input shape, it benchmarks candidate kernels and caches the fastest one, which pays off when input sizes stay fixed across batches.
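The range difference is easy to demonstrate with NumPy. NumPy has no native bfloat16, so this sketch emulates it by keeping the top 16 bits of a float32: sign, 8 exponent bits, and 7 mantissa bits. That emulation truncates rather than rounding to nearest even as hardware does, which is an assumption made for illustration. The exponent, and therefore the range, is unaffected.

```python
import numpy as np


def to_bf16(x: float) -> float:
    """Emulate BFloat16 by keeping only the top 16 bits of a float32.

    Truncates (rounds toward zero); real hardware rounds to nearest even,
    but the exponent, and therefore the representable range, is identical.
    """
    bits = np.array([x], dtype=np.float32).view(np.uint32)
    bits &= np.uint32(0xFFFF0000)  # keep sign, 8 exponent bits, 7 mantissa bits
    return float(bits.view(np.float32)[0])


with np.errstate(over="ignore"):
    big_fp16 = float(np.float16(1e5))   # exceeds FP16's maximum of 65504
small_fp16 = float(np.float16(1e-8))     # below FP16's smallest subnormal

print("fp16:", big_fp16, small_fp16)
print("bf16:", to_bf16(1e5), to_bf16(1e-8))
```

FP16 overflows to `inf` on the large value and flushes the small one to zero, which is why FP16 training needs loss scaling. BF16 keeps both values within about 1% relative error, trading mantissa precision (7 bits vs FP16's 10) for FP32's range.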

### Fault tolerance and recovery

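Multi-node GPU jobs fail: hardware hiccups, spot instance preemptions, NCCL timeouts, and memory spikes are a matter of when, not if. This pipeline handles failures with a two-part system: a callback that uploads recovery checkpoints during training, and a driver-side handler that returns the latest checkpoint when training fails.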

After every validation epoch, the `RecoveryArtifactCallback` calls `trainer.save_checkpoint()` to write a DeepSpeed checkpoint directory, then uploads all shard files to the recovery URI. Each node's local rank 0 uploads its own shards; global rank 0 uploads the metadata files (`metrics.json`, `summary.json`). A distributed barrier between save and upload ensures all workers finish before training continues.
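The division of upload work between ranks can be sketched as a pure function. The function name `upload_responsibilities` and the shard file names are hypothetical. The barrier that precedes the uploads is omitted because it needs a live distributed process group.

```python
def upload_responsibilities(global_rank: int, local_rank: int,
                            node_shards: list[str],
                            metadata_files: tuple[str, ...] = ("metrics.json", "summary.json")) -> list[str]:
    """Return the files this rank uploads after a checkpoint is saved (hypothetical helper)."""
    files: list[str] = []
    if local_rank == 0:    # one uploader per node, for the shards on that node's disk
        files.extend(node_shards)
    if global_rank == 0:   # exactly one uploader for run-level metadata
        files.extend(metadata_files)
    return files


NUM_NODES, DEVICES_PER_NODE = 2, 4
for node in range(NUM_NODES):
    shards = [f"node{node}_shard.pt"]  # hypothetical shard names
    for local_rank in range(DEVICES_PER_NODE):
        global_rank = node * DEVICES_PER_NODE + local_rank
        print(global_rank, upload_responsibilities(global_rank, local_rank, shards))
```

On 2 × 4 GPUs, only global ranks 0 and 4 upload anything, so each shard is uploaded exactly once and the metadata is never written twice.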

If training fails, the driver task catches the error and returns the last recovery artifact instead of propagating the failure:

```
    try:
        with wandb_config(
            project=wandb_project,
            entity=wandb_entity,
        ):
            training_artifacts = train_qwen_adapter_multinode(
                train_manifest=train_manifest,
                val_manifest=val_manifest,
                images_dir=images,
                config=config,
                resume_from=resume_training_artifacts,
                recovery_uri=recovery_uri,
            )
    except flyte.errors.RuntimeUserError as e:
        if recovery_uri is None:
            raise e
        print(f"Training failed - recovering latest checkpoint bundle: {recovery_uri}")
        try:
            recovered_artifacts = Dir(path=recovery_uri)
            recovered_root = await download_dir_async(recovered_artifacts)
            flyte.report.log(
                build_qwen_adapter_report_html(recovered_root, None),
                do_flush=True,
            )
            return recovered_artifacts
        except Exception:
            raise e
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/qwen_vl_frozen_backbone_finetuning/tasks.py*

A failed run still produces useful output: the most recent checkpoint saved before the failure, along with a partial training report. To resume from that point, pass the recovery artifact as `resume_training_artifacts` on the next run. The training task downloads it, finds the most recent `.ckpt` file, and passes it to `trainer.fit()` as `ckpt_path`. Training picks up at the last saved epoch with optimizer state and metrics history intact.
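Selecting the most recent checkpoint after a download is subtle, because downloaded files get fresh modification times. This sketch therefore orders candidates by the counters in their names rather than by mtime. It assumes Lightning-style names such as `epoch=3-step=120.ckpt`; the tutorial's callback may name files differently. `find_latest_ckpt` is a hypothetical helper. `rglob` also matches directories, which covers DeepSpeed checkpoints, since those are saved as directories ending in `.ckpt`.

```python
import re
from pathlib import Path
from typing import Optional

_COUNTER = re.compile(r"(epoch|step)=(\d+)")


def _ckpt_key(path: Path) -> tuple[int, int]:
    """Sort key parsed from names such as 'epoch=3-step=120.ckpt' (assumed naming)."""
    found = {name: int(value) for name, value in _COUNTER.findall(path.name)}
    return found.get("epoch", -1), found.get("step", -1)


def find_latest_ckpt(root: str) -> Optional[Path]:
    """Return the newest *.ckpt file or checkpoint directory under root, or None."""
    candidates = sorted(Path(root).rglob("*.ckpt"), key=_ckpt_key)
    return candidates[-1] if candidates else None
```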

The recovery URI is constructed from the configurable base path and the run name:

```
s3://your-bucket/qwen-vl-multinode-deepspeed/<run-name>/qwen_vl_training_recovery/
```

This means each run gets its own recovery location, so you can identify exactly which run a checkpoint came from.
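The construction itself is plain string joining. In this sketch, `build_recovery_uri` is a hypothetical helper and `run_name` stands for however the driver obtains the current run name. Stripping a trailing slash from the base keeps the result well-formed whether or not the configured base URI ends in `/`.

```python
def build_recovery_uri(checkpoint_base_uri: str, run_name: str) -> str:
    """Return <base>/<run-name>/qwen_vl_training_recovery/ (hypothetical helper)."""
    return f"{checkpoint_base_uri.rstrip('/')}/{run_name}/qwen_vl_training_recovery/"


print(build_recovery_uri("s3://your-bucket/qwen-vl-multinode-deepspeed", "my-run"))
```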

### Live observability

`flyte.report` lets you push HTML content directly into the Flyte UI during task execution, with no separate monitoring infrastructure. The `LiveTrainingReportCallback` uses this to stream training metrics in real time:

```
    def _push_update(
        self,
        *,
        trainer,
        pl_module,
        status: str,
        phase: str,
        train_total=None,
        train_lm=None,
        train_recon=None,
        val_total=None,
        note: str,
    ) -> None:
        adapter_gate = float(torch.tanh(pl_module.adapter.gate).detach().cpu())

        def fmt(value):
            return f"{float(value):.4f}" if value is not None else "-"

        payload = {
            "step": trainer.global_step,
            "phase": phase,
            "train_total": fmt(train_total),
            "train_lm": fmt(train_lm),
            "train_recon": fmt(train_recon),
            "val_total": fmt(val_total),
            "train_total_value": (
                float(train_total) if train_total is not None else None
            ),
            "val_total_value": float(val_total) if val_total is not None else None,
            "adapter_gate": f"{adapter_gate:.4f}",
            "status": status,
            "resumed_from": self.resumed_from or "fresh run",
            "recovery_path": self.recovery_callback.latest_path
            or "pending first checkpoint upload",
            "note": note,
        }
        flyte.report.log(
            f"""
            <script>
            if (typeof window.updateQwenLiveReport === "function") {{
                window.updateQwenLiveReport({json.dumps(payload)});
            }}
            </script>
            """,
            do_flush=True,
        )
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/qwen_vl_frozen_backbone_finetuning/callbacks.py*

`on_train_start` (see the full code) initializes the dashboard with an SVG loss chart and an HTML metrics table. Every `report_every_n_steps` training steps, `_push_update` serializes the latest metrics into a `<script>` block and calls `flyte.report.log()` to append it to the live page. The JavaScript `updateQwenLiveReport()` function then updates the chart polylines and appends a new table row for each step.
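The throttle-and-push pattern can be tested outside Flyte. In this sketch, the `log` callable stands in for `flyte.report.log(html, do_flush=True)`. The class `ThrottledReportPusher` and the JavaScript function name `updateLiveReport` are hypothetical (the tutorial's function is `updateQwenLiveReport`). The sketch also escapes `</` in the JSON, a precaution so that a note containing `</script>` cannot close the tag early.

```python
import json
from typing import Callable


class ThrottledReportPusher:
    """Push a <script> update every `report_every_n_steps` steps via `log`."""

    def __init__(self, log: Callable[[str], None], report_every_n_steps: int = 5):
        self.log = log
        self.report_every_n_steps = report_every_n_steps

    def maybe_push(self, step: int, payload: dict) -> bool:
        if step % self.report_every_n_steps != 0:
            return False
        # Escape "</" so a string such as "</script>" inside the payload
        # cannot terminate the surrounding <script> tag early.
        data = json.dumps(payload).replace("</", "<\\/")
        self.log(
            "<script>\n"
            'if (typeof window.updateLiveReport === "function") {\n'
            f"  window.updateLiveReport({data});\n"
            "}\n"
            "</script>"
        )
        return True


pushed: list[str] = []
pusher = ThrottledReportPusher(pushed.append, report_every_n_steps=5)
for step in range(11):
    pusher.maybe_push(step, {"step": step, "train_total": f"{1.0 / (step + 1):.4f}"})
```

Appending small script fragments rather than re-rendering the whole page keeps each update cheap, since the page's JavaScript applies each delta to the existing chart and table.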

For resumed runs, the prior metrics history is seeded into the table on `on_train_start`, so the metrics view is continuous across runs rather than restarting from zero.

![Recovery](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/tutorials/qwen-vl-finetuning/recovery.png)

WandB metrics are logged in parallel by `AdapterMetricsCallback` after each validation epoch, including per-epoch train and validation losses, the LM loss component, the reconstruction loss component, and the current adapter gate value.

![WandB](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/tutorials/qwen-vl-finetuning/wandb.png)

### Evaluation

After training completes, a separate task runs inference on a single GPU:

```
@evaluation_env.task
async def evaluate_qwen_adapter(
    val_manifest: JsonlFile,
    images_dir: Dir,
    training_artifacts: Dir,
    config: Config,
) -> Dir:
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/qwen_vl_frozen_backbone_finetuning/tasks.py*

The task is async so it can `asyncio.gather` the downloads of the training artifacts and images concurrently rather than sequentially, a simple speedup that matters when downloading hundreds of megabytes from object storage.
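The benefit of `asyncio.gather` shows up even with simulated I/O. In this sketch, `fake_download` is a hypothetical stand-in for an async object-store download. With two 0.2-second downloads, the gathered version finishes in roughly 0.2 seconds instead of 0.4.

```python
import asyncio
import time


async def fake_download(name: str, seconds: float) -> str:
    """Hypothetical stand-in for an async object-store download."""
    await asyncio.sleep(seconds)
    return name


async def download_inputs() -> list[str]:
    # Both downloads are in flight at once; total time is roughly the slower one.
    return await asyncio.gather(
        fake_download("training_artifacts", 0.2),
        fake_download("images", 0.2),
    )


start = time.perf_counter()
names = asyncio.run(download_inputs())
elapsed = time.perf_counter() - start
print(names, f"{elapsed:.2f}s")
```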

The evaluation task loads the adapter state dict from `adapter_artifact.pt`, rebuilds the frozen backbone fresh from HuggingFace (there's no need to checkpoint 3B weights; only the ~10,500 adapter weights travel with the artifact), and runs greedy decoding on each validation example. The metric is exact-match accuracy between the model's predicted class token and the ground-truth CIFAR-10 label.
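An exact-match metric needs a normalisation rule, so that `"Cat."` and `"cat"` compare equal. In this sketch, `normalize_label` and `exact_match_accuracy` are hypothetical. The normalisation (trim, lowercase, drop a trailing period) is an assumption, not necessarily the tutorial's rule.

```python
def normalize_label(text: str) -> str:
    """Lower-case and strip whitespace and a trailing period (assumed normalisation)."""
    return text.strip().lower().rstrip(".")


def exact_match_accuracy(predictions: list[str], labels: list[str]) -> float:
    """Fraction of predictions that exactly match their label after normalisation."""
    if len(predictions) != len(labels):
        raise ValueError("predictions and labels must have the same length")
    if not labels:
        return 0.0
    hits = sum(normalize_label(p) == normalize_label(t) for p, t in zip(predictions, labels))
    return hits / len(labels)


print(exact_match_accuracy(["Cat", "dog.", "ship"], ["cat", "dog", "truck"]))
```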

### Putting it all together

The driver task is the pipeline entry point that all other tasks flow through:

```
@driver_env.task(report=True)
async def qwen_vl_multinode_deepspeed(
    model_name: str = DEFAULT_MODEL_NAME,
    max_train_samples: int = 1024,
    max_val_samples: int = 256,
    epochs: int = 8,
    per_device_batch_size: int = 1,
    target_global_batch_size: int = 16,
    learning_rate: float = 2e-4,
    reconstruction_loss_weight: float = 0.35,
    eval_examples: int = 16,
    resume_training_artifacts: Optional[Dir] = None,
    checkpoint_base_uri: Optional[str] = DEFAULT_CHECKPOINT_BASE_URI,
    wandb_project: str = "qwen-vl-multinode-deepspeed",
    wandb_entity: Optional[str] = None,
) -> Optional[Dir]:
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/qwen_vl_frozen_backbone_finetuning/tasks.py*

The driver constructs the recovery URI from `checkpoint_base_uri` and the current run name, prepares the dataset (or retrieves it from cache), then executes training inside a `wandb_config` context. The `wandb_config` context manager creates and registers a WandB run; the `@wandb_init` decorator on the training task retrieves it, updates it with the full `Config` dataclass, and attaches a `WandbLogger`. Neither the training task nor the callbacks need any WandB initialization code of their own.
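The division of labour between the context manager and the decorator can be sketched in-process with `contextvars`. This is not the WandB plugin's implementation. The names `tracking_config` and `tracking_init` are hypothetical, and a plain dict stands in for a WandB run. In the real pipeline, the training task runs in separate containers, so the plugin has to carry this context across the task boundary; the sketch only shows the in-process shape of the pattern.

```python
import contextvars
import functools
from contextlib import contextmanager
from typing import Optional

# Holds the "run" registered by the context manager for the current context.
_active_run: contextvars.ContextVar[Optional[dict]] = contextvars.ContextVar("active_run", default=None)


@contextmanager
def tracking_config(project: str, entity: Optional[str] = None):
    """Register a run for everything executed inside the `with` block."""
    token = _active_run.set({"project": project, "entity": entity, "config": {}})
    try:
        yield
    finally:
        _active_run.reset(token)


def tracking_init(fn):
    """Inject the registered run so the task body needs no setup code."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        run = _active_run.get()
        if run is None:
            raise RuntimeError("call this inside a tracking_config(...) block")
        return fn(*args, run=run, **kwargs)
    return wrapper


@tracking_init
def train(epochs: int, *, run: dict) -> str:
    run["config"]["epochs"] = epochs  # analogous to updating the run with the Config dataclass
    return run["project"]


with tracking_config(project="qwen-vl-multinode-deepspeed"):
    print(train(8))
```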

The recovery handler (shown in the previous section) wraps the training call. If training succeeds, evaluation runs next. The driver then downloads the training and evaluation artifacts concurrently, assembles a final HTML report (training curves, evaluation summary, per-epoch metrics table, and sample prediction cards showing the actual occluded images), and pushes it to Flyte Reports.

## Running the tutorial

Before running, update two placeholders in `config.py`:

- `DEFAULT_CHECKPOINT_BASE_URI`: your S3, GCS, or Azure Blob URI for checkpoint storage
- The `wandb_api_key` secret key name to match your cluster's secret store configuration

Then configure and launch:

```
if __name__ == "__main__":
    flyte.init_from_config()

    run = flyte.run(
        qwen_vl_multinode_deepspeed,
        model_name=DEFAULT_MODEL_NAME,
        max_train_samples=512,
        max_val_samples=128,
        epochs=5,
        per_device_batch_size=1,
        target_global_batch_size=16,
        learning_rate=2e-4,
        reconstruction_loss_weight=0.35,
        eval_examples=16,
        checkpoint_base_uri=DEFAULT_CHECKPOINT_BASE_URI,
        wandb_project="qwen-vl-multinode-deepspeed",
        wandb_entity="<YOUR_WANDB_ENTITY>",  # TODO: update with your own wandb entity
        # resume_training_artifacts=Dir(
        #     path="s3://flyte-examples/qwen-vl-multinode-deepspeed/<ACTION_NAME>/qwen_vl_training_recovery/"
        # ),
    )

    print(f"Run URL: {run.url}")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/qwen_vl_frozen_backbone_finetuning/train.py*

```bash
flyte create config --endpoint <YOUR_ENDPOINT> --project <PROJECT> --domain <DOMAIN> --builder remote
uv run train.py
```

When you run this, the pipeline:

1. **Builds containers** once and caches them for subsequent runs
2. **Prepares the dataset**: downloads CIFAR-10, generates occlusions, writes JSONL manifests; cached on subsequent runs with the same config
3. **Launches multi-node training**: provisions 2 × 4 L40s GPUs and starts the Elastic job
4. **Streams metrics to the live dashboard**: the Flyte Reports view starts updating as soon as the first step logs
5. **Runs evaluation**: a single-GPU task loads the adapter and runs inference, computing exact-match accuracy
6. **Generates the final report**: training curves, evaluation summary, and sample prediction cards appear in the Flyte UI

![Final Report](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/tutorials/qwen-vl-finetuning/final_report.png)
![Predictions](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/tutorials/qwen-vl-finetuning/predictions.png)

To resume a failed or interrupted run, uncomment the `resume_training_artifacts` line in `train.py` and point it to the recovery URI from the previous run. Training picks up from the last checkpoint with metrics history intact.

## Going further

**Adapting this to a different task.** The frozen backbone pattern transfers directly. Replace `QwenOcclusionDataset` and `prepare_occlusion_dataset` with your own data, update the prompt template, and adjust the dual loss if a pixel-level reconstruction term doesn't apply to your task. The multi-node Elastic setup, DeepSpeed Stage 2 config, recovery system, and live reporting are completely reusable.

**Using a larger Qwen model.** Change `DEFAULT_MODEL_NAME` to `Qwen/Qwen2.5-VL-7B-Instruct` or a larger variant. You may need to increase `memory` in `training_env` and reduce `per_device_batch_size`. The frozen backbone + adapter pattern becomes more valuable at larger scales, because you train the same ~10,500-parameter adapter regardless of backbone size.

**Training keeps failing.** Add `retries=3` to the `@training_env.task` decorator. With the recovery callback uploading checkpoints after every validation epoch, Flyte automatically restarts training from the last checkpoint on transient failures. Spot instance preemptions and most hardware hiccups become non-events.

**Scaling to more nodes.** Increase `NUM_NODES` in `config.py`. The Elastic plugin, DeepSpeed strategy, and gradient accumulation calculation all adapt automatically. The recovery system needs no changes, because each run still gets its own recovery URI.

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/computer-vision/multimodal-retrieval-evaluation ===

# Multimodal retrieval evaluation

> [!NOTE]
> Code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/multimodal-retrieval-evaluation).

This tutorial builds an experiment framework for benchmarking **visual document retrieval** on the [ViDoRe benchmark](https://huggingface.co/vidore). The corpus is a set of PDF page *images* and the queries are plain-text questions — each retrieval method must find the page that answers a question from the raw image alone.

Three approaches are compared:

- **ColPali-v1.2** — patch-level multi-vector embeddings from a vision-language model (PaliGemma), scored with MaxSim late interaction. No OCR.
- **SigLIP-SO400M** — a single global embedding per page from Google's CLIP successor.
- **OCR + BM25** — a text-only baseline that OCRs each page with [docTR](https://github.com/mindee/doctr) on GPU and ranks with BM25.
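The difference between MaxSim late interaction and single-vector scoring can be seen on toy embeddings. This NumPy sketch scores one query at a time; the pipeline's ColPali search batches the same computation over queries with `torch.einsum("ctd,pjd->ctpj", ...)`. The helper names and the one-hot toy vectors are illustrative assumptions.

```python
import numpy as np


def maxsim_scores(query_tokens: np.ndarray, page_patches: np.ndarray) -> np.ndarray:
    """Late-interaction score per page.

    query_tokens: (T, D); page_patches: (P, J, D). Returns (P,).
    For each query token, take its best-matching patch on the page, then sum.
    """
    sim = np.einsum("td,pjd->tpj", query_tokens, page_patches)  # (T, P, J)
    return sim.max(axis=2).sum(axis=0)


def cosine_scores(query_vec: np.ndarray, page_vecs: np.ndarray) -> np.ndarray:
    """Single-vector score per page: cosine similarity. page_vecs: (P, D)."""
    q = query_vec / np.linalg.norm(query_vec)
    p = page_vecs / np.linalg.norm(page_vecs, axis=1, keepdims=True)
    return p @ q


e = np.eye(4)
query = np.stack([e[0], e[1]])              # two query tokens
pages = np.stack([
    np.stack([e[2], e[3]]),                 # page 0: matches neither token
    np.stack([e[0], e[3]]),                 # page 1: matches token 0 only
    np.stack([e[0], e[1]]),                 # page 2: a patch for each token
])
scores = maxsim_scores(query, pages)
print(scores, np.argsort(-scores))
```

Each query token gets credit for its single best patch, so a page scores highly when different regions answer different parts of the question. A single global vector has to average those regions together.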

Each experiment is one `ExperimentConfig`; the pipeline fans them out as concurrent Flyte tasks and returns a ranked comparison table with an interactive HTML report. It's a strong showcase of several Flyte features working together:

- **`ReusePolicy`** keeps warm GPU containers (with the ~7 GB ColPali weights already in VRAM) alive across task calls.
- A process-level **`DynamicBatcher`** aggregates queries from all concurrent search tasks into single GPU batches.
- **`cache="auto"`** ensures a model's index is built at most once per corpus and shared across experiments.
- **Typed Pydantic inputs/outputs** store every metric alongside the exact config that produced it.
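The core idea behind cross-caller batching fits in a short asyncio sketch. `MiniBatcher` is a hypothetical class, not the `extras.DynamicBatcher` API. The tutorial configures its batcher with additional knobs (`target_batch_cost`, `default_cost`, `prefetch_batches`) that this sketch omits. Callers submit single items and await their own result. A background loop waits for the first item, keeps collecting until the batch is full or a short timeout expires, then makes one `process_fn` call for the whole batch.

```python
import asyncio
from typing import Awaitable, Callable, Generic, Optional, TypeVar

T = TypeVar("T")
R = TypeVar("R")


class MiniBatcher(Generic[T, R]):
    """Collect items from concurrent callers into batches for one process_fn call."""

    def __init__(self, process_fn: Callable[[list[T]], Awaitable[list[R]]],
                 max_batch_size: int = 128, batch_timeout_s: float = 0.05):
        self.process_fn = process_fn
        self.max_batch_size = max_batch_size
        self.batch_timeout_s = batch_timeout_s
        self.queue: asyncio.Queue = asyncio.Queue()
        self.batch_sizes: list[int] = []
        self._task: Optional[asyncio.Task] = None

    def start(self) -> None:
        self._task = asyncio.create_task(self._run())

    async def stop(self) -> None:
        if self._task is not None:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass

    async def submit(self, item: T) -> R:
        fut = asyncio.get_running_loop().create_future()
        await self.queue.put((item, fut))
        return await fut

    async def _run(self) -> None:
        loop = asyncio.get_running_loop()
        while True:
            batch = [await self.queue.get()]  # block until the first item arrives
            deadline = loop.time() + self.batch_timeout_s
            while len(batch) < self.max_batch_size:  # then collect until full or timed out
                remaining = deadline - loop.time()
                if remaining <= 0:
                    break
                try:
                    batch.append(await asyncio.wait_for(self.queue.get(), remaining))
                except asyncio.TimeoutError:
                    break
            self.batch_sizes.append(len(batch))
            try:
                results = await self.process_fn([item for item, _ in batch])
            except Exception as exc:  # fail every caller in this batch
                for _, fut in batch:
                    if not fut.done():
                        fut.set_exception(exc)
                continue
            for (_, fut), result in zip(batch, results):
                if not fut.done():
                    fut.set_result(result)


async def demo() -> tuple[list[int], list[int]]:
    async def double_all(xs: list[int]) -> list[int]:
        return [2 * x for x in xs]  # one "GPU call" per batch

    batcher: MiniBatcher[int, int] = MiniBatcher(double_all, max_batch_size=8)
    batcher.start()
    results = await asyncio.gather(*(batcher.submit(i) for i in range(20)))
    await batcher.stop()
    return list(results), batcher.batch_sizes


print(asyncio.run(demo()))
```

Twenty concurrent callers end up sharing a handful of `process_fn` calls instead of making twenty. This is the effect that lets many concurrent search tasks on one warm container share GPU forward passes.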

## Define the container image

One image serves every task. `unionai-reuse` provides the actor bridge required by `ReusePolicy`.

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "colpali-engine>=0.3.1",
#     "transformers>=4.41",
#     "sentencepiece>=0.2",
#     "torch>=2.0",
#     "pillow>=10",
#     "datasets>=2.18",
#     "rank-bm25>=0.2",
#     "numpy>=1.26",
#     "python-doctr[torch]>=0.8",
#     "pydantic>=2.0",
#     "flyte>=2.0.0",
# ]
# ///
"""
Multimodal Retrieval Evaluation Pipeline

This tutorial is an experiment framework for benchmarking visual document
retrieval approaches on the ViDoRe benchmark. Each experiment is defined by
an ExperimentConfig; the pipeline fans them out as concurrent Flyte tasks and
returns a ranked comparison table with an interactive HTML report.

The corpus is a set of PDF page images; queries are plain-text questions. Each
retrieval method must find the page that answers each question — no text is
provided to the model, only the raw image.

  ColPali-v1.2  — patch-level multi-vector embeddings from a VLM (PaliGemma).
                  No OCR. The model produces one vector per image patch
                  (~1024 per page). MaxSim late-interaction scoring finds the
                  best matching patch for each query token.

  SigLIP-SO400M — single global embedding per page from Google's 2023 CLIP
                  successor. One matrix multiply per query; fast and effective
                  but a single vector cannot localise fine-grained regions.

  OCR + BM25    — text-only baseline. doctr (GPU OCR) extracts text in
                  batches, BM25 matches keywords. Strong on text-dense pages;
                  fails on charts, tables, and figures where content is visual.

"""

import asyncio
import enum
import json
import math
import os
import tempfile
from functools import lru_cache
from io import BytesIO
from itertools import islice

import numpy as np
from PIL import Image as PILImage
from pydantic import BaseModel
from rank_bm25 import BM25Okapi

from extras import DynamicBatcher

import flyte
import flyte.report
from flyte.io import File

# ─────────────────────────────────────────────────────────────────────────────
# Environments
# ─────────────────────────────────────────────────────────────────────────────

# One Docker image for all tasks. The PEP 723 header defines Python deps.
# ca-certificates is required for HTTPS calls to HuggingFace and blob stores.
# {{docs-fragment image}}
image = (
    flyte.Image.from_uv_script(__file__, name="vidore-eval-v2")
    .with_apt_packages("ca-certificates", "libxcb1", "libgl1", "libglib2.0-0")
    # unionai-reuse installs the unionai-actor-bridge binary required by ReusePolicy.
    # Without it every reusable container exits with StartError (exit code 128).
    .with_pip_packages("unionai-reuse>=0.1.11")
)
# {{/docs-fragment image}}

# GPU environment for ColPali image encoding and search.
#
# ReusePolicy keeps a warm GPU container alive between task calls (replicas=1).
# Without it, every task invocation cold-starts a new container and downloads
# ColPali-v1.2 (~7 GB) from scratch. With it, the container — and the model
# weights already loaded into VRAM — is reused for the next task dispatch.
#
#   replicas=1      single warm container — all concurrent shard calls land
#                   here so they share one DynamicBatcher process
#   concurrency=8   up to 8 query-shard tasks run simultaneously on the
#                   container, all feeding the same DynamicBatcher queue
#   idle_ttl=120    keep alive 2 min after the last task finishes
#   scaledown_ttl=60 scale to zero after 1 min of complete inactivity
# {{docs-fragment envs}}
colpali_indexer = flyte.TaskEnvironment(
    name="vidore-colpali-indexer",
    image=image,
    resources=flyte.Resources(cpu=4, memory="16Gi", gpu="A10G:1"),
    reusable=flyte.ReusePolicy(
        replicas=1,
        concurrency=8,
        idle_ttl=120,
        scaledown_ttl=60,
    ),
)

# GPU environment for SigLIP image encoding and search.
#
# Separate from the ColPali environment so each model's warm containers
# are managed independently — ColPali and SigLIP experiments can scale
# without contending for the same pool of reusable containers.
siglip_indexer = flyte.TaskEnvironment(
    name="vidore-siglip-indexer",
    image=image,
    resources=flyte.Resources(cpu=4, memory="8Gi", gpu=1),
    reusable=flyte.ReusePolicy(
        replicas=1,
        concurrency=8,
        idle_ttl=120,
        scaledown_ttl=60,
    ),
)

# GPU environment for doctr OCR. doctr runs DBNet (detection) + CRNN (recognition)
# in batches on GPU — much faster than CPU Tesseract.
# No ReusePolicy needed: the result is cached, so this task runs at most once.
ocr_engine = flyte.TaskEnvironment(
    name="vidore-ocr-engine",
    image=image,
    resources=flyte.Resources(cpu=4, memory="20Gi", gpu=1),
)

# Driver: orchestration, BM25 search, evaluation, and reporting.
# depends_on ensures the shared Docker image is built before all environments
# try to schedule tasks.
driver = flyte.TaskEnvironment(
    name="vidore-driver",
    image=image,
    resources=flyte.Resources(cpu=2, memory="12Gi"),
    depends_on=[colpali_indexer, siglip_indexer, ocr_engine],
)
# {{/docs-fragment envs}}

# ─────────────────────────────────────────────────────────────────────────────
# Configuration types
# ─────────────────────────────────────────────────────────────────────────────

# {{docs-fragment config_types}}
class RetrievalModel(str, enum.Enum):
    """Retrieval backend to evaluate."""

    COLPALI = "colpali-v1.2"  # multi-vector patch embeddings, MaxSim
    SIGLIP = "siglip-so400m"  # single-vector global embedding, cosine sim
    OCR_BM25 = "ocr+bm25"  # text extracted by doctr OCR, ranked by BM25

class ExperimentConfig(BaseModel):
    """
    All knobs for one retrieval experiment. Passed as a typed Flyte input.

    Because ExperimentConfig is a Pydantic model, Flyte serialises it
    alongside every task output — so you can always reconstruct which
    config produced which metric without maintaining a separate log.
    """

    name: str  # human-readable label shown in the comparison table
    model: RetrievalModel
    top_k: int = 5  # number of pages to retrieve per query
# {{/docs-fragment config_types}}

# ─────────────────────────────────────────────────────────────────────────────
# Data types
# ─────────────────────────────────────────────────────────────────────────────

# {{docs-fragment data_types}}
class PageQuery(BaseModel):
    """One retrieval query with its ground-truth page."""

    query_id: str
    text: str  # e.g. "What was revenue growth in Q3?"
    relevant_page_id: str  # one correct page per query

class PageDataset(BaseModel):
    """
    A corpus of document page images paired with text queries.

    page_ids:   unique page identifiers (derived from ViDoRe image filenames).
    page_files: the same pages stored in Flyte's blob store as JPEG File
                handles. Tasks read images directly from here; no live HTTP.
    queries:    text questions with ground-truth page IDs for evaluation.
    """

    page_ids: list[str]
    page_files: list[File]
    queries: list[PageQuery]

    model_config = {"arbitrary_types_allowed": True}

class RetrievalResult(BaseModel):
    query_id: str
    ranked_page_ids: list[str]  # ordered best → worst

class Metrics(BaseModel):
    recall_at_k: float
    ndcg_at_k: float
    mrr: float
    k: int

class ExperimentResult(BaseModel):
    config: ExperimentConfig
    metrics: Metrics
# {{/docs-fragment data_types}}

class ComparisonReport(BaseModel):
    results: list[ExperimentResult]

    def best_by(self, metric: str = "recall_at_k") -> ExperimentResult:
        return max(self.results, key=lambda r: getattr(r.metrics, metric))

    def summary(self) -> str:
        header = f"{'Experiment':<30} {'Model':<18} {'Recall@K':>10} {'NDCG@K':>8} {'MRR':>7}"
        sep = "─" * len(header)
        rows = [header, sep]
        for r in sorted(self.results, key=lambda x: -x.metrics.recall_at_k):
            rows.append(
                f"{r.config.name:<30} "
                f"{r.config.model.value:<18} "
                f"{r.metrics.recall_at_k:>10.3f} "
                f"{r.metrics.ndcg_at_k:>8.3f} "
                f"{r.metrics.mrr:>7.3f}"
            )
        return "\n".join(rows)

# ─────────────────────────────────────────────────────────────────────────────
# Cached model loaders
# ─────────────────────────────────────────────────────────────────────────────
# These functions are at module level so they are shared across all tasks that
# run on the same warm container (via ReusePolicy). lru_cache(maxsize=1) means
# the model is loaded from disk/HuggingFace exactly once per container process
# and kept in GPU memory for every subsequent task dispatch to that container.

@lru_cache(maxsize=1)
def _colpali_model():
    """Load ColPali-v1.2 into GPU memory and cache the result.

    device_map= is the correct loading pattern for ColPali's PaliGemma
    backbone; it handles weight placement via accelerate. torch.compile is
    skipped — ColPali is GPU-compute-bound and the DynamicBatcher's cross-
    invocation batching is the primary GPU utilisation mechanism.
    """
    import torch
    from colpali_engine.models import ColPali, ColPaliProcessor

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = ColPali.from_pretrained(
        "vidore/colpali-v1.2",
        torch_dtype=torch.bfloat16,
        device_map=device,
    )
    processor = ColPaliProcessor.from_pretrained("vidore/colpali-v1.2")
    return model, processor, device

@lru_cache(maxsize=1)
def _siglip_model():
    """Load SigLIP SO400M into GPU memory, compile it, and cache the result.

    torch.compile (mode="reduce-overhead") fuses the vision and text encoder
    transformer layers into optimised CUDA kernels. Unlike ColPali, which skips
    torch.compile, SigLIP pays this compilation overhead once per warm
    container lifetime.
    """
    import torch
    from transformers import AutoModel, AutoProcessor

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModel.from_pretrained("google/siglip-so400m-patch14-224").to(device)
    if device == "cuda":
        model = torch.compile(model, mode="reduce-overhead")
    processor = AutoProcessor.from_pretrained("google/siglip-so400m-patch14-224")
    return model, processor, device

@lru_cache(maxsize=1)
def _ocr_model():
    """Load the doctr OCR predictor onto GPU and cache it.

    doctr's ocr_predictor bundles a detection model (DBNet) and a
    recognition model (CRNN/SAR) into a single callable. pretrained=True
    downloads both model weights from doctr's model zoo on first use.
    """
    import torch
    from doctr.models import ocr_predictor

    predictor = ocr_predictor(pretrained=True)
    if torch.cuda.is_available():
        predictor = predictor.cuda()
    return predictor

# ─────────────────────────────────────────────────────────────────────────────
# Search batcher singletons
# ─────────────────────────────────────────────────────────────────────────────
# One DynamicBatcher per model, shared across all concurrent search task
# invocations on the same warm container (concurrency=8). Queries from every
# concurrent caller are aggregated into a single GPU batch, maximizing
# throughput compared to each invocation running its own forward pass.
#
# Initialised lazily on the first search call via double-checked locking and
# lives for the container's lifetime. The process_fn runs GPU work via
# asyncio.to_thread so the aggregation loop can continue collecting queries
# from other callers while the GPU processes the current batch.
#
# File is not hashable so alru_cache cannot be used here; module-level state
# with asyncio.Lock is the correct pattern.
#
# Assumption: index_colpali/index_siglip use cache="auto", so the same corpus
# always produces the same index File across all callers on this container. If
# the index file ever changed between calls, the batcher would silently continue
# using the corpus embeddings loaded from the first call.

_colpali_batcher: DynamicBatcher | None = None
_colpali_batcher_lock = asyncio.Lock()
_siglip_batcher: DynamicBatcher | None = None
_siglip_batcher_lock = asyncio.Lock()

async def _get_colpali_search_batcher(index_file: File) -> DynamicBatcher:
    """Return the process-level ColPali search batcher, creating it on first call."""
    global _colpali_batcher
    if _colpali_batcher is not None:
        return _colpali_batcher
    async with _colpali_batcher_lock:
        if _colpali_batcher is not None:
            return _colpali_batcher

        import torch

        data = await _load_npz(index_file)
        corpus_emb = torch.from_numpy(data["embeddings"])  # (n_pages, n_patches, dim)
        index_page_ids: list[str] = list(data["page_ids"])
        model, processor, device = _colpali_model()
        corpus_emb = corpus_emb.to(device, dtype=torch.float32)

        async def colpali_process_fn(batch: list[PageQuery]) -> list[list[str]]:
            def _gpu_work() -> list[list[str]]:
                query_inputs = processor.process_queries([q.text for q in batch])
                query_inputs = {k: v.to(device) for k, v in query_inputs.items()}
                with torch.no_grad():
                    query_embs = model(**query_inputs).float()  # (B, T, D)
                    query_chunk = 8
                    n_pages = corpus_emb.shape[0]
                    all_scores = torch.empty(len(batch), n_pages, device=device)
                    for start in range(0, len(batch), query_chunk):
                        chunk = query_embs[start : start + query_chunk]
                        all_scores[start : start + query_chunk] = (
                            torch.einsum("ctd,pjd->ctpj", chunk, corpus_emb)
                            .max(dim=3).values
                            .sum(dim=1)
                        )
                    sorted_indices = all_scores.argsort(dim=1, descending=True).cpu().tolist()
                return [[index_page_ids[j] for j in ranked] for ranked in sorted_indices]

            # Run GPU work in a thread so the event loop — and the batcher's
            # aggregation loop — remain unblocked while the GPU is busy.
            return await asyncio.to_thread(_gpu_work)

        batcher: DynamicBatcher[PageQuery, list[str]] = DynamicBatcher(
            process_fn=colpali_process_fn,
            target_batch_cost=128,
            max_batch_size=128,
            batch_timeout_s=0.05,
            default_cost=1,
            prefetch_batches=2,
        )
        await batcher.start()
        _colpali_batcher = batcher
    return _colpali_batcher

async def _get_siglip_search_batcher(index_file: File) -> DynamicBatcher:
    """Return the process-level SigLIP search batcher, creating it on first call."""
    global _siglip_batcher
    if _siglip_batcher is not None:
        return _siglip_batcher
    async with _siglip_batcher_lock:
        if _siglip_batcher is not None:
            return _siglip_batcher

        import torch

        data = await _load_npz(index_file)
        corpus_emb = torch.from_numpy(data["embeddings"])  # (n_pages, dim), L2-normalised
        index_page_ids: list[str] = list(data["page_ids"])
        model, processor, device = _siglip_model()
        corpus_emb = corpus_emb.to(device)

        async def siglip_process_fn(batch: list[PageQuery]) -> list[list[str]]:
            def _gpu_work() -> list[list[str]]:
                text_inputs = processor(
                    text=[q.text for q in batch],
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                ).to(device)
                with torch.no_grad():
                    text_out = model.text_model(**text_inputs)
                    query_embs = text_out.pooler_output  # (B, dim)
                    query_embs = query_embs / query_embs.norm(dim=-1, keepdim=True)
                    scores_matrix = corpus_emb @ query_embs.T  # (n_pages, B)
                    sorted_indices = scores_matrix.argsort(dim=0, descending=True).T.cpu().tolist()
                return [[index_page_ids[j] for j in ranked] for ranked in sorted_indices]

            return await asyncio.to_thread(_gpu_work)

        batcher = DynamicBatcher(
            process_fn=siglip_process_fn,
            target_batch_cost=128,
            max_batch_size=128,
            batch_timeout_s=0.05,
            default_cost=1,
            prefetch_batches=2,
        )
        await batcher.start()
        _siglip_batcher = batcher
    return _siglip_batcher

# ─────────────────────────────────────────────────────────────────────────────
# Helpers
# ─────────────────────────────────────────────────────────────────────────────

def _batches(items: list, batch_size: int):
    """Yield successive fixed-size batches from a list."""
    for start in range(0, len(items), batch_size):
        yield items[start : start + batch_size]

def _load_image_sync(f: File) -> PILImage.Image:
    """Blocking download + decode. Intended to be called from a thread pool."""
    with f.open_sync("rb") as fh:
        data = fh.read()
    return PILImage.open(BytesIO(data)).convert("RGB")

async def _load_image(f: File) -> PILImage.Image:
    """Download and decode a page image in a thread-pool worker.

    asyncio.to_thread runs _load_image_sync in a worker OS thread, so the
    event loop stays responsive during the blocking download and decode.
    The indexing tasks call loop.run_in_executor directly instead, so they can
    submit the next batch's downloads before the current GPU forward pass.
    """
    return await asyncio.to_thread(_load_image_sync, f)

async def _load_npz(index_file: File) -> np.lib.npyio.NpzFile:
    """Download an index File to a local temp path and open with np.load."""
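    with tempfile.NamedTemporaryFile(suffix=".npz", delete=False) as tmp:
        async with index_file.open("rb") as fh:
            tmp.write(bytes(await fh.read()))
    # Load only after the with block has closed (and flushed) the temp file;
    # np.load reopens it by path and could otherwise read a truncated archive.
    return np.load(tmp.name)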

def _dcg(relevances: list[int]) -> float:
    """Discounted cumulative gain: the item at 0-based rank r is discounted by log2(r + 2)."""
    return sum(rel / math.log2(rank + 2) for rank, rel in enumerate(relevances))

# ─────────────────────────────────────────────────────────────────────────────
# Tasks — data loading
# ─────────────────────────────────────────────────────────────────────────────

@driver.task(cache="auto", retries=3)
async def load_vidore_pages(subset: str = "docvqa", max_pages: int = 200) -> PageDataset:
    """
    Load a ViDoRe benchmark subset and store page images in Flyte's blob store.

    Supports two dataset formats:

    Legacy (subsampled) — single 'test' split with one row per (query, page)
    pair; fields: image, query, image_filename. streaming=True reads only the
    rows requested via islice — no full-shard download.
    Datasets: vidore/docvqa_test_subsampled, vidore/infovqa_test_subsampled

    V3 — separate corpus / queries / qrels splits following the BEIR retrieval
    benchmark format. corpus contains page images; queries contains question
    text; qrels maps query IDs to relevant corpus page IDs (many-to-many).
    Datasets: vidore/vidore_v3_finance_en  (~2 942 pages, 1 854 queries)

    The first call uploads page images to Flyte's blob store and caches the
    PageDataset; every subsequent call with the same arguments returns the
    cached result instantly. retries=3 guards against transient HuggingFace
    network failures.

    Available subsets: "docvqa", "infovqa", "vidore_v3_finance_en"
    """
    from datasets import load_dataset

    subset_map = {
        "docvqa": "vidore/docvqa_test_subsampled",
        "infovqa": "vidore/infovqa_test_subsampled",
        "vidore_v3_finance_en": "vidore/vidore_v3_finance_en",
    }
    dataset_name = subset_map.get(subset, f"vidore/{subset}_test_subsampled")

    # V3 datasets ship with separate corpus / queries / qrels splits.
    _V3_SUBSETS = {"vidore_v3_finance_en"}

    if subset in _V3_SUBSETS:
        # ── V3 format ─────────────────────────────────────────────────────────
        # corpus / queries / qrels are HuggingFace configs (name=), not splits.
        # corpus uses streaming=True so images are decoded one at a time —
        # loading all 2 942 rows eagerly would hold gigabytes of PIL images in
        # the driver's RAM simultaneously. qrels and queries are text-only and
        # small enough to load fully into memory.
        corpus_ds = load_dataset(dataset_name, name="corpus", split="test", streaming=True)
        qrels_ds = load_dataset(dataset_name, name="qrels", split="test")
        queries_ds = load_dataset(dataset_name, name="queries", split="test")

        # Normalise field names — V3 follows BEIR convention (hyphenated ids).
        def _col(ds, *candidates):
            cols = set(ds.column_names)
            for c in candidates:
                if c in cols:
                    return c
            raise KeyError(f"None of {candidates} found in columns {cols}")

        corpus_id_col = _col(corpus_ds, "corpus-id", "corpus_id", "id", "_id")
        query_id_col = _col(queries_ds, "query-id", "query_id", "id", "_id")
        query_text_col = _col(queries_ds, "query", "text")
        qrel_qid_col = _col(qrels_ds, "query-id", "query_id")
        qrel_cid_col = _col(qrels_ds, "corpus-id", "corpus_id")

        # Slice corpus to max_pages, upload each image to Flyte blob store.
        page_ids: list[str] = []
        page_files: list[File] = []
        corpus_id_to_page_id: dict[str, str] = {}

        for i, row in enumerate(islice(corpus_ds, max_pages)):
            img = row.get("image")
            if not isinstance(img, PILImage.Image):
                continue
            cid = str(row[corpus_id_col])
            page_id = f"{subset}_{i:04d}"
            with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f:
                tmp_path = f.name
                img.convert("RGB").save(tmp_path, format="JPEG")
            del img  # free PIL memory before upload
            page_file = await File.from_local(tmp_path)
            os.unlink(tmp_path)
            corpus_id_to_page_id[cid] = page_id
            page_ids.append(page_id)
            page_files.append(page_file)

        # Build query_id → relevant page_id from qrels (first match wins).
        # Only keep relevance judgements whose corpus page is in our slice.
        qrel_map: dict[str, str] = {}
        for row in qrels_ds:
            qid = str(row[qrel_qid_col])
            cid = str(row[qrel_cid_col])
            if cid in corpus_id_to_page_id and qid not in qrel_map:
                qrel_map[qid] = corpus_id_to_page_id[cid]

        # Collect queries that have at least one relevant page in our slice.
        queries: list[PageQuery] = []
        for row in queries_ds:
            qid = str(row[query_id_col])
            if qid not in qrel_map:
                continue
            queries.append(
                PageQuery(
                    query_id=qid,
                    text=str(row[query_text_col]),
                    relevant_page_id=qrel_map[qid],
                )
            )

    else:
        # ── Legacy format ─────────────────────────────────────────────────────
        # Single 'test' split with one row per (query, page) pair.
        ds = load_dataset(dataset_name, split="test", streaming=True)

        page_ids = []
        page_files = []
        queries = []
        seen_pages: dict[str, str] = {}  # image_filename → page_id

        for i, row in enumerate(islice(ds, max_pages)):
            img = row.get("image")
            if not isinstance(img, PILImage.Image):
                continue
            filename: str = row.get("image_filename") or f"page_{i}"
            query_text: str = row.get("query", "")
            if not query_text:
                continue

            # Each unique page is uploaded exactly once; multiple queries may
            # share the same page (same image_filename).
            if filename not in seen_pages:
                page_id = f"{subset}_{len(page_ids):04d}"
                with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f:
                    tmp_path = f.name
                    img.convert("RGB").save(tmp_path, format="JPEG")
                del img  # free PIL memory before upload
                page_file = await File.from_local(tmp_path)
                os.unlink(tmp_path)
                seen_pages[filename] = page_id
                page_ids.append(page_id)
                page_files.append(page_file)
            else:
                page_id = seen_pages[filename]

            queries.append(
                PageQuery(
                    query_id=f"q{i:04d}",
                    text=query_text,
                    relevant_page_id=page_id,
                )
            )

    print(f"Loaded {len(page_ids)} unique pages, {len(queries)} queries", flush=True)
    return PageDataset(page_ids=page_ids, page_files=page_files, queries=queries)

# ─────────────────────────────────────────────────────────────────────────────
# Tasks — indexing
# ─────────────────────────────────────────────────────────────────────────────

@colpali_indexer.task(cache="auto", retries=2)
async def index_colpali(page_ids: list[str], page_files: list[File]) -> File:
    """
    Encode every page with ColPali-v1.2 and save the multi-vector index.

    ColPali skips OCR entirely. It feeds the raw page image into PaliGemma
    (a vision-language model) and produces one embedding vector per image
    patch — roughly 1,024 patches per page, each of dimension 128.

    _colpali_model() is an lru_cache'd loader. On a cold container, it
    downloads and loads the model once. On a warm container (kept alive by
    ReusePolicy), it returns the already-loaded model instantly from cache —
    no repeated ~7 GB download.

    The index is stored as a .npz file in Flyte's blob store:
      embeddings — float32, shape (n_pages, n_patches, dim)
      page_ids   — matching page ID strings

    cache="auto" + retries=2: the result is stored permanently on success;
    transient failures (e.g. HuggingFace rate limits) are retried twice.
    """
    import torch

    model, processor, device = _colpali_model()

    loop = asyncio.get_running_loop()
    batches = list(_batches(page_files, 4))
    n_batches = len(batches)

    # Submit the first batch to the thread pool before entering the loop so
    # that downloads are already in flight when we first await them.
    prefetch = [loop.run_in_executor(None, _load_image_sync, f) for f in batches[0]]

    all_embeddings: list[np.ndarray] = []
    for batch_idx in range(n_batches):
        images = list(await asyncio.gather(*prefetch))

        # Submit next batch downloads immediately — OS threads run these in
        # parallel with the GPU forward pass below.
        if batch_idx + 1 < n_batches:
            prefetch = [loop.run_in_executor(None, _load_image_sync, f) for f in batches[batch_idx + 1]]

        inputs = processor.process_images(images)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            emb = model(**inputs)  # (batch, n_patches, dim)

        all_embeddings.append(emb.cpu().float().numpy())
        print(f"ColPali: indexed batch {batch_idx + 1}/{n_batches}", flush=True)

    embeddings = np.concatenate(all_embeddings, axis=0)  # (n_pages, n_patches, dim)
    out_path = os.path.join(tempfile.gettempdir(), "colpali_index.npz")
    np.savez(out_path, embeddings=embeddings, page_ids=np.array(page_ids))
    return await File.from_local(out_path)

@siglip_indexer.task(cache="auto", retries=2)
async def index_siglip(page_ids: list[str], page_files: list[File]) -> File:
    """
    Encode every page with SigLIP SO400M and save the single-vector index.

    SigLIP (2023) is Google's sigmoid-loss variant of CLIP: it scores each
    image-text pair independently with a sigmoid loss, which removes the
    batch-wide normalisation that CLIP's softmax contrastive loss requires.
    Produces one global embedding per page.

    _siglip_model() caches the model across warm container reuses.

    The index is stored as a .npz file:
      embeddings — float32, shape (n_pages, dim), L2-normalised
      page_ids   — matching page ID strings
    """
    import torch

    model, processor, device = _siglip_model()

    loop = asyncio.get_running_loop()
    batches = list(_batches(page_files, 8))
    n_batches = len(batches)

    # Submit the first batch to the thread pool before entering the loop so
    # that downloads are already in flight when we first await them.
    prefetch = [loop.run_in_executor(None, _load_image_sync, f) for f in batches[0]]

    all_embeddings: list[np.ndarray] = []
    for batch_idx in range(n_batches):
        images = list(await asyncio.gather(*prefetch))

        # Submit next batch downloads immediately — OS threads run these in
        # parallel with the GPU forward pass below.
        if batch_idx + 1 < n_batches:
            prefetch = [loop.run_in_executor(None, _load_image_sync, f) for f in batches[batch_idx + 1]]

        inputs = processor(images=images, return_tensors="pt", padding=True).to(device)

        with torch.no_grad():
            outputs = model.vision_model(**inputs)
            emb = outputs.pooler_output  # (batch, dim)
            emb = emb / emb.norm(dim=-1, keepdim=True)  # L2 normalise

        all_embeddings.append(emb.cpu().float().numpy())
        print(f"SigLIP: indexed batch {batch_idx + 1}/{n_batches}", flush=True)

    embeddings = np.concatenate(all_embeddings, axis=0)  # (n_pages, dim)
    out_path = os.path.join(tempfile.gettempdir(), "siglip_index.npz")
    np.savez(out_path, embeddings=embeddings, page_ids=np.array(page_ids))
    return await File.from_local(out_path)

@ocr_engine.task(cache="auto")
async def extract_page_texts(page_files: list[File]) -> list[str]:
    """
    OCR every page with doctr on GPU to produce a text-only baseline.

    doctr bundles DBNet (detection) + CRNN/SAR (recognition) into a single
    callable predictor. Pages are processed in batches of ocr_batch_size: each
    batch is downloaded in parallel, then passed to the predictor.
    asyncio.to_thread keeps the event loop unblocked during GPU inference.

    Result structure: result.pages[i].blocks[j].lines[k].words[l].value

    Cached: the same corpus is OCR'd at most once across all experiments
    that use the OCR+BM25 backend.
    """
    import gc

    predictor = _ocr_model()

    # Process in batches: download each batch just-in-time so only
    # ocr_batch_size images are in memory at once instead of the whole corpus.
    ocr_batch_size = 8
    total = len(page_files)
    texts: list[str] = []
    for start in range(0, total, ocr_batch_size):
        batch_files = page_files[start : start + ocr_batch_size]
        batch_images = list(
            await asyncio.gather(*[asyncio.to_thread(_load_image_sync, f) for f in batch_files])
        )
        batch_np = [np.array(img) for img in batch_images]
        del batch_images
        result = await asyncio.to_thread(predictor, batch_np)
        del batch_np
        for page_output in result.pages:
            texts.append(
                "\n".join(
                    " ".join(word.value for word in line.words)
                    for block in page_output.blocks
                    for line in block.lines
                )
            )
        del result
        gc.collect()
        print(f"OCR: processed {min(start + ocr_batch_size, total)}/{total} pages", flush=True)

    return texts

# ─────────────────────────────────────────────────────────────────────────────
# Tasks — search
# ─────────────────────────────────────────────────────────────────────────────

# {{docs-fragment search_colpali}}
@colpali_indexer.task
async def search_colpali(
    index_file: File,
    queries: list[PageQuery],
    top_k: int,
) -> list[RetrievalResult]:
    """
    Retrieve pages using ColPali MaxSim late interaction via DynamicBatcher.

    MaxSim score for page p given query q:
        score(q, p) = Σ_{t ∈ query tokens} max_{j ∈ page patches} (q_t · p_j)

    Each query is submitted to the process-level DynamicBatcher, which
    aggregates queries from all concurrent search_colpali invocations on the
    same warm container (concurrency=8) into a single GPU batch. This keeps
    the GPU saturated rather than running one small batch per caller.

    The batcher's process_fn runs GPU work in asyncio.to_thread, so the
    aggregation loop stays live while the GPU encodes and scores.
    """
    batcher = await _get_colpali_search_batcher(index_file)
    futures = await batcher.submit_batch(queries)
    all_ranked: list[list[str]] = list(await asyncio.gather(*futures))

    return [
        RetrievalResult(query_id=q.query_id, ranked_page_ids=ranked[:top_k])
        for q, ranked in zip(queries, all_ranked)
    ]
# {{/docs-fragment search_colpali}}

@siglip_indexer.task
async def search_siglip(
    index_file: File,
    queries: list[PageQuery],
    top_k: int,
) -> list[RetrievalResult]:
    """
    Retrieve pages using SigLIP cosine similarity via DynamicBatcher.

    Each query is submitted to the process-level DynamicBatcher, which
    aggregates queries from all concurrent search_siglip invocations on the
    same warm container (concurrency=3) into a single GPU batch.

    SigLIP's single-vector embeddings make full vectorisation safe —
    the scores matrix (n_pages x n_queries) is small enough to materialise
    in one GPU call regardless of batch size.
    """
    batcher = await _get_siglip_search_batcher(index_file)
    futures = await batcher.submit_batch(queries)
    all_ranked: list[list[str]] = list(await asyncio.gather(*futures))

    return [
        RetrievalResult(query_id=q.query_id, ranked_page_ids=ranked[:top_k])
        for q, ranked in zip(queries, all_ranked)
    ]

@driver.task
async def search_bm25(
    page_texts: list[str],
    page_ids: list[str],
    queries: list[PageQuery],
    top_k: int,
) -> list[RetrievalResult]:
    """
    Retrieve pages using BM25 over OCR'd text.

    The standard keyword-based baseline. No GPU required; strong on
    text-dense pages, weak on visual content the doctr OCR pass cannot read.
    """
    tokenized = [text.lower().split() for text in page_texts]
    bm25 = BM25Okapi(tokenized)

    results: list[RetrievalResult] = []
    for q in queries:
        scores = bm25.get_scores(q.text.lower().split())
        ranked = sorted(range(len(page_ids)), key=lambda i: -scores[i])[:top_k]
        results.append(
            RetrievalResult(
                query_id=q.query_id,
                ranked_page_ids=[page_ids[i] for i in ranked],
            )
        )
    return results

# ─────────────────────────────────────────────────────────────────────────────
# Tasks — evaluation
# ─────────────────────────────────────────────────────────────────────────────

@driver.task
async def evaluate(
    results: list[RetrievalResult],
    ground_truth: list[PageQuery],
    k: int,
) -> Metrics:
    """
    Compute Recall@K, NDCG@K, and MRR for a single retrieval model.

    Recall@K  — was the correct page in the top-K results?
    NDCG@K    — normalised discounted cumulative gain; rewards earlier hits.
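    MRR       — mean reciprocal rank of the first correct result. The search
                tasks return only the top-K pages, so this is effectively
                MRR@K: a query whose page falls outside the top K scores 0.
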
    All three are averaged over all queries. Higher is better.
    """
    gt_map = {q.query_id: q.relevant_page_id for q in ground_truth}
    recall_vals, ndcg_vals, mrr_vals = [], [], []

    for r in results:
        relevant = gt_map.get(r.query_id, "")
        top = r.ranked_page_ids[:k]

        recall_vals.append(1.0 if relevant in top else 0.0)

        rels = [1 if pid == relevant else 0 for pid in top]
        idcg = _dcg([1])  # ideal: correct page at rank 1
        ndcg_vals.append(_dcg(rels) / idcg if idcg > 0 else 0.0)

        rr = 0.0
        for rank, pid in enumerate(r.ranked_page_ids, start=1):
            if pid == relevant:
                rr = 1.0 / rank
                break
        mrr_vals.append(rr)

    return Metrics(
        recall_at_k=float(np.mean(recall_vals)),
        ndcg_at_k=float(np.mean(ndcg_vals)),
        mrr=float(np.mean(mrr_vals)),
        k=k,
    )

# ─────────────────────────────────────────────────────────────────────────────
# Tasks — report
# ─────────────────────────────────────────────────────────────────────────────

@driver.task(report=True)
async def generate_report(report: ComparisonReport) -> None:
    """
    Emit an interactive HTML report visible in the Flyte UI.

    report=True marks this task as a reporting task. Flyte renders the HTML
    passed to flyte.report.replace.aio() directly in the execution detail
    page, so no separate dashboard or export step is required.

    The report contains:
      - Summary cards: experiment count, best model, best Recall@K.
      - Grouped bar chart: Recall@K, NDCG@K, MRR side-by-side per experiment.
      - Ranked results table with all three metrics.
    """
    sorted_results = sorted(report.results, key=lambda r: -r.metrics.recall_at_k)
    best = sorted_results[0]

    labels = [r.config.name for r in sorted_results]
    recall_vals = [r.metrics.recall_at_k for r in sorted_results]
    ndcg_vals = [r.metrics.ndcg_at_k for r in sorted_results]
    mrr_vals = [r.metrics.mrr for r in sorted_results]

    table_rows = "".join(
        f"""
        <tr>
          <td>{r.config.name}</td>
          <td>{r.config.model.value}</td>
          <td>{r.metrics.recall_at_k:.3f}</td>
          <td>{r.metrics.ndcg_at_k:.3f}</td>
          <td>{r.metrics.mrr:.3f}</td>
          <td>{r.metrics.k}</td>
        </tr>"""
        for r in sorted_results
    )

    html = f"""<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <title>Visual Document Retrieval — Results</title>
  <script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
  <style>
    * {{ box-sizing: border-box; margin: 0; padding: 0; }}
    body {{
      font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;
      background: #f0f2f5; color: #222; padding: 24px;
    }}
    h1 {{ font-size: 1.6em; margin-bottom: 4px; }}
    .subtitle {{ color: #666; margin-bottom: 24px; font-size: 0.95em; }}
    .cards {{
      display: flex; gap: 16px; flex-wrap: wrap; margin-bottom: 28px;
    }}
    .card {{
      background: #fff; border-radius: 10px; padding: 18px 24px;
      box-shadow: 0 1px 4px rgba(0,0,0,.08); min-width: 160px;
    }}
    .card-value {{ font-size: 1.9em; font-weight: 700; color: #4f46e5; }}
    .card-label {{ font-size: 0.8em; color: #888; text-transform: uppercase;
                   letter-spacing: .04em; margin-top: 2px; }}
    .chart-box {{
      background: #fff; border-radius: 10px; padding: 24px;
      box-shadow: 0 1px 4px rgba(0,0,0,.08); margin-bottom: 28px;
    }}
    .chart-box h2 {{ font-size: 1em; margin-bottom: 16px; color: #444; }}
    table {{ width: 100%; border-collapse: collapse; font-size: 0.9em; }}
    th {{
      background: #4f46e5; color: #fff; padding: 10px 14px;
      text-align: left; font-weight: 600;
    }}
    td {{ padding: 9px 14px; border-bottom: 1px solid #eee; }}
    tr:hover td {{ background: #f8f8ff; }}
    tr:first-child td {{ font-weight: 600; }}
  </style>
</head>
<body>
  <h1>Visual Document Retrieval — Experiment Comparison</h1>
  <p class="subtitle">ViDoRe benchmark &middot; {len(report.results)} experiment(s)</p>

  <div class="cards">
    <div class="card">
      <div class="card-value">{len(report.results)}</div>
      <div class="card-label">Experiments</div>
    </div>
    <div class="card">
      <div class="card-value">{best.config.name}</div>
      <div class="card-label">Best by Recall@K</div>
    </div>
    <div class="card">
      <div class="card-value">{best.metrics.recall_at_k:.3f}</div>
      <div class="card-label">Best Recall@{best.metrics.k}</div>
    </div>
    <div class="card">
      <div class="card-value">{best.metrics.ndcg_at_k:.3f}</div>
      <div class="card-label">NDCG@{best.metrics.k} (best run)</div>
    </div>
    <div class="card">
      <div class="card-value">{best.metrics.mrr:.3f}</div>
      <div class="card-label">MRR (best run)</div>
    </div>
  </div>

  <div class="chart-box">
    <h2>Metrics by Experiment</h2>
    <canvas id="metricsChart" height="100"></canvas>
  </div>

  <div class="chart-box">
    <h2>Ranked Results</h2>
    <table>
      <thead>
        <tr>
          <th>Experiment</th><th>Model</th>
          <th>Recall@K</th><th>NDCG@K</th><th>MRR</th><th>K</th>
        </tr>
      </thead>
      <tbody>{table_rows}</tbody>
    </table>
  </div>

  <script>
    new Chart(document.getElementById('metricsChart'), {{
      type: 'bar',
      data: {{
        labels: {json.dumps(labels)},
        datasets: [
          {{
            label: 'Recall@K',
            data: {json.dumps(recall_vals)},
            backgroundColor: 'rgba(79,70,229,0.75)',
            borderRadius: 4
          }},
          {{
            label: 'NDCG@K',
            data: {json.dumps(ndcg_vals)},
            backgroundColor: 'rgba(16,185,129,0.75)',
            borderRadius: 4
          }},
          {{
            label: 'MRR',
            data: {json.dumps(mrr_vals)},
            backgroundColor: 'rgba(245,158,11,0.75)',
            borderRadius: 4
          }}
        ]
      }},
      options: {{
        responsive: true,
        plugins: {{ legend: {{ position: 'top' }} }},
        scales: {{
          y: {{ beginAtZero: true, max: 1.0,
               title: {{ display: true, text: 'Score' }} }}
        }}
      }}
    }});
  </script>
</body>
</html>"""

    await flyte.report.replace.aio(html)
    await flyte.report.flush.aio()

# ─────────────────────────────────────────────────────────────────────────────
# Experiment orchestration
# ─────────────────────────────────────────────────────────────────────────────

# {{docs-fragment run_experiment}}
@driver.task
async def run_experiment(config: ExperimentConfig, dataset: PageDataset) -> ExperimentResult:
    """
    End-to-end retrieval pipeline for a single ExperimentConfig.

    Flyte v2's dynamic execution means this driver task can call GPU tasks
    (index_colpali, search_colpali) based on the runtime value of config.model
    — no static DAG wiring required. The if/elif is plain Python; Flyte
    schedules the selected sub-tasks on the appropriate environment.

    Caching: two experiments that share the same model and corpus (e.g. ColPali
    at top_k=5 and top_k=10) will hit the same cached index. GPU work is paid
    at most once per (model, corpus) pair across all experiments.

    Search queries are sharded into chunks of SEARCH_SHARD_SIZE and dispatched
    as concurrent task invocations. All shards land on the single warm container
    (replicas=1) and feed the same DynamicBatcher simultaneously, keeping the
    GPU saturated throughout search rather than processing one large sequential
    batch from a single caller.

    flyte.group wraps each experiment in a named span in the Flyte UI, making
    it easy to compare latencies and drill into individual runs.
    """
    SEARCH_SHARD_SIZE = 256

    with flyte.group(config.name):
        if config.model == RetrievalModel.COLPALI:
            index_file = await index_colpali(dataset.page_ids, dataset.page_files)
            shards = list(_batches(dataset.queries, SEARCH_SHARD_SIZE))
            shard_results = await asyncio.gather(
                *[search_colpali(index_file, shard, config.top_k) for shard in shards]
            )
            results = [r for shard in shard_results for r in shard]

        elif config.model == RetrievalModel.SIGLIP:
            index_file = await index_siglip(dataset.page_ids, dataset.page_files)
            shards = list(_batches(dataset.queries, SEARCH_SHARD_SIZE))
            shard_results = await asyncio.gather(
                *[search_siglip(index_file, shard, config.top_k) for shard in shards]
            )
            results = [r for shard in shard_results for r in shard]

        else:  # RetrievalModel.OCR_BM25
            page_texts = await extract_page_texts(dataset.page_files)
            results = await search_bm25(page_texts, dataset.page_ids, dataset.queries, config.top_k)

        metrics = await evaluate(results, dataset.queries, config.top_k)

    return ExperimentResult(config=config, metrics=metrics)
# {{/docs-fragment run_experiment}}

# {{docs-fragment compare_experiments}}
@driver.task
async def compare_experiments(
    configs: list[ExperimentConfig],
    subset: str = "docvqa",
    max_pages: int = 200,
) -> ComparisonReport:
    """
    Fan out over all experiment configs and return a ranked comparison table.

    The dataset is loaded once and shared across all experiments. Each config
    runs as a concurrent Flyte task via asyncio.gather. Experiments that share
    a model reuse the cached index — you only pay GPU time for new work.

    On completion, generate_report emits an interactive Chart.js HTML report
    visible directly in the Flyte execution detail page.

    The function defaults (subset="docvqa", max_pages=200) give a quick run.
    The __main__ entry point instead uses vidore_v3_finance_en (~2 942 corpus
    pages, 1 854 queries) with max_pages=2000 to exercise the GPU pipeline at scale.
    """
    dataset = await load_vidore_pages(subset=subset, max_pages=max_pages)

    # All experiments launch concurrently. Shared cached outputs (same model,
    # same corpus) are served from cache rather than recomputed.
    experiment_coros = [run_experiment(config=cfg, dataset=dataset) for cfg in configs]
    results: list[ExperimentResult] = list(await asyncio.gather(*experiment_coros))

    report = ComparisonReport(results=results)
    print(report.summary())
    best = report.best_by("recall_at_k")
    print(f"\nBest by Recall@{best.metrics.k}: {best.config.name}")

    # Emit the interactive HTML report in the Flyte UI.
    await generate_report(report)

    return report
# {{/docs-fragment compare_experiments}}

# ─────────────────────────────────────────────────────────────────────────────
# Entry point
# ─────────────────────────────────────────────────────────────────────────────

if __name__ == "__main__":
    flyte.init_from_config()

    # Define the experiment grid. Each ExperimentConfig is one point in the
    # design space. Varying top_k, or adding another experiment for an existing
    # model, is one line here; no task code changes are required.
    #
    # ColPali appears twice with different top_k values. The cache ensures
    # index_colpali runs only once and both experiments share that result.
    # {{docs-fragment grid}}
    configs = [
        ExperimentConfig(name="colpali-top5", model=RetrievalModel.COLPALI, top_k=5),
        ExperimentConfig(name="colpali-top10", model=RetrievalModel.COLPALI, top_k=10),
        ExperimentConfig(name="siglip-top5", model=RetrievalModel.SIGLIP, top_k=5),
        ExperimentConfig(name="ocr-bm25-top5", model=RetrievalModel.OCR_BM25, top_k=5),
    ]
    # {{/docs-fragment grid}}

    run = flyte.with_runcontext().run(
        compare_experiments,
        configs=configs,
        subset="vidore_v3_finance_en",
        max_pages=2000,
    )
    print(f"Run URL: {run.url}")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/multimodal-retrieval-evaluation/retrieval_eval.py*
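
To make the MaxSim formula in `search_colpali` concrete, here is a small NumPy sketch of the same scoring that `colpali_process_fn` performs with `torch.einsum`. The helper name `maxsim_scores` and the toy tensors are hypothetical, chosen so the scores can be checked by hand; this is an illustration, not code from the tutorial.

```python
import numpy as np


def maxsim_scores(query_embs: np.ndarray, corpus_emb: np.ndarray) -> np.ndarray:
    """Late-interaction (MaxSim) scores for a batch of queries.

    query_embs: (n_queries, n_tokens, dim)   one vector per query token
    corpus_emb: (n_pages, n_patches, dim)    one vector per page patch
    returns:    (n_queries, n_pages)
    """
    # sim[c, t, p, j] = dot(token t of query c, patch j of page p)
    sim = np.einsum("ctd,pjd->ctpj", query_embs, corpus_emb)
    # Keep each query token's best-matching patch, then sum over tokens.
    return sim.max(axis=3).sum(axis=1)


# Toy example: 1 query with 2 tokens, 2 pages with 2 patches each, dim=2.
q = np.array([[[1.0, 0.0], [0.0, 1.0]]])
pages = np.array([
    [[1.0, 0.0], [0.0, 0.0]],  # page 0 matches only the first query token
    [[1.0, 0.0], [0.0, 1.0]],  # page 1 matches both query tokens
])
scores = maxsim_scores(q, pages)
print(scores)                       # [[1. 2.]]
print(np.argsort(-scores, axis=1))  # [[1 0]]: page 1 ranks first
```

Page 1 wins because both query tokens find an exact match among its patches, while page 0 matches only the first token. The tutorial computes the same quantity in chunks of `query_chunk` queries, presumably to bound the size of the intermediate `(queries, tokens, pages, patches)` tensor on the GPU.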

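As a worked example of what `evaluate` computes when each query has exactly one relevant page: the ideal DCG is `_dcg([1]) = 1`, so NDCG@K reduces to `1 / log2(rank + 1)` for a hit at 1-based `rank`. The sketch reuses the tutorial's `_dcg`; the function `score_query` and the page IDs are hypothetical.

```python
import math


def _dcg(relevances: list[int]) -> float:
    # Same discount as the tutorial: 0-based position r is divided by log2(r + 2).
    return sum(rel / math.log2(rank + 2) for rank, rel in enumerate(relevances))


def score_query(ranked: list[str], relevant: str, k: int) -> tuple[float, float, float]:
    """Return (recall@k, ndcg@k, reciprocal rank) for one query with one relevant page."""
    top = ranked[:k]
    recall = 1.0 if relevant in top else 0.0
    ndcg = _dcg([1 if pid == relevant else 0 for pid in top]) / _dcg([1])
    rr = next((1.0 / rank for rank, pid in enumerate(ranked, start=1) if pid == relevant), 0.0)
    return recall, ndcg, rr


# The relevant page sits at rank 3 of a top-5 list.
recall, ndcg, rr = score_query(["p7", "p2", "p9", "p1", "p4"], relevant="p9", k=5)
print(recall, round(ndcg, 3), round(rr, 3))  # 1.0 0.5 0.333
```

A hit at rank 3 still counts fully for Recall@5, but NDCG halves it (1 / log2(4)) and MRR divides it by three, which is why the report shows all three metrics side by side.
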
The Python dependencies (ColPali, transformers, docTR, etc.) are declared in the `uv` script header at the top of the file.

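Both indexing tasks pipeline I/O and compute: the next batch's downloads are handed to the default thread pool with `loop.run_in_executor` before the current batch's forward pass runs. The sketch below reproduces that loop with hypothetical stand-ins (`load_item_sync`, `process_batch`) that sleep instead of downloading images or running a model, and logs events so the overlap can be observed.

```python
import asyncio
import time

events: list[tuple[str, int]] = []  # shared log; list.append is atomic in CPython


def load_item_sync(batch_idx: int, item: int) -> int:
    """Hypothetical stand-in for _load_image_sync: a blocking download + decode."""
    events.append(("load", batch_idx))
    time.sleep(0.05)
    return item


def process_batch(batch_idx: int, items: list[int]) -> int:
    """Hypothetical stand-in for the blocking GPU forward pass."""
    time.sleep(0.05)
    events.append(("done", batch_idx))
    return sum(items)


async def pipelined(batches: list[list[int]]) -> list[int]:
    loop = asyncio.get_running_loop()

    def submit(idx: int) -> list[asyncio.Future]:
        # run_in_executor submits to the thread pool immediately, not on await.
        return [loop.run_in_executor(None, load_item_sync, idx, x) for x in batches[idx]]

    prefetch = submit(0)
    out = []
    for idx in range(len(batches)):
        items = list(await asyncio.gather(*prefetch))
        if idx + 1 < len(batches):
            prefetch = submit(idx + 1)  # next downloads run during the compute below
        out.append(process_batch(idx, items))
    return out


batches = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10]]
results = asyncio.run(pipelined(batches))
print(results)  # [10, 26, 19]
```

Like `model(**inputs)` in `index_colpali`, `process_batch` blocks the event-loop thread, yet the downloads still progress because the pool threads were started at submission time; the event log shows batch 1 loading before batch 0 finishes.
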
## Define the task environments

Each retrieval backend (ColPali, SigLIP, and the docTR OCR engine) gets its own GPU environment, so the warm-container pools scale independently. The ColPali and SigLIP environments use `ReusePolicy` to keep model weights resident between invocations; the driver handles dataset loading, orchestration, BM25 search, evaluation, and reporting.
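
Keeping weights resident depends on the loaders being memoized: the `index_colpali` docstring notes that `_colpali_model()` is `lru_cache`'d. A minimal sketch of that pattern, with a hypothetical `load_model` that counts how often the expensive load actually runs within one process (one warm container).

```python
from functools import lru_cache

load_count = 0


@lru_cache(maxsize=1)
def load_model() -> dict:
    """Hypothetical stand-in for _colpali_model(): the expensive load runs once per process."""
    global load_count
    load_count += 1
    return {"weights": [0.0] * 1_000}  # pretend these are model weights


def task_invocation(query: str) -> int:
    model = load_model()  # cache hit after the first call in this process
    return len(model["weights"]) + len(query)


for q in ["a", "bb", "ccc"]:
    task_invocation(q)
print(load_count)  # 1
```

A cold container starts a new Python process with an empty cache, so its first invocation still pays the full load; reuse only helps invocations that land on an already-warm container.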

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "colpali-engine>=0.3.1",
#     "transformers>=4.41",
#     "sentencepiece>=0.2",
#     "torch>=2.0",
#     "pillow>=10",
#     "datasets>=2.18",
#     "rank-bm25>=0.2",
#     "numpy>=1.26",
#     "python-doctr[torch]>=0.8",
#     "pydantic>=2.0",
#     "flyte>=2.0.0",
# ]
# ///
"""
Multimodal Retrieval Evaluation Pipeline

This tutorial is an experiment framework for benchmarking visual document
retrieval approaches on the ViDoRe benchmark. Each experiment is defined by
an ExperimentConfig; the pipeline fans them out as concurrent Flyte tasks and
returns a ranked comparison table with an interactive HTML report.

The corpus is a set of PDF page images; queries are plain-text questions. Each
retrieval method must find the page that answers each question — no text is
provided to the model, only the raw image.

  ColPali-v1.2  — patch-level multi-vector embeddings from a VLM (PaliGemma).
                  No OCR. The model produces one vector per image patch
                  (~1024 per page). MaxSim late-interaction scoring finds the
                  best matching patch for each query token.

  SigLIP-SO400M — single global embedding per page from Google's 2023 CLIP
                  successor. One matrix multiply per query; fast and effective
                  but a single vector cannot localise fine-grained regions.

  OCR + BM25    — text-only baseline. doctr (GPU OCR) extracts text in
                  batches, BM25 matches keywords. Strong on text-dense pages;
                  fails on charts, tables, and figures where content is visual.

"""

import asyncio
import enum
import json
import math
import os
import tempfile
from functools import lru_cache
from io import BytesIO
from itertools import islice

import numpy as np
from PIL import Image as PILImage
from pydantic import BaseModel
from rank_bm25 import BM25Okapi

from extras import DynamicBatcher

import flyte
import flyte.report
from flyte.io import File

# ─────────────────────────────────────────────────────────────────────────────
# Environments
# ─────────────────────────────────────────────────────────────────────────────

# One Docker image for all tasks. The PEP 723 header defines Python deps.
# ca-certificates is required for HTTPS calls to HuggingFace and blob stores.
# {{docs-fragment image}}
image = (
    flyte.Image.from_uv_script(__file__, name="vidore-eval-v2")
    .with_apt_packages("ca-certificates", "libxcb1", "libgl1", "libglib2.0-0")
    # unionai-reuse installs the unionai-actor-bridge binary required by ReusePolicy.
    # Without it every reusable container exits with StartError (exit code 128).
    .with_pip_packages("unionai-reuse>=0.1.11")
)
# {{/docs-fragment image}}

# GPU environment for ColPali image encoding and search.
#
# ReusePolicy keeps warm GPU containers alive between task calls (one replica here).
# Without it, every task invocation cold-starts a new container and downloads
# ColPali-v1.2 (~7 GB) from scratch. With it, the container — and the model
# weights already loaded into VRAM — is reused for the next task dispatch.
#
#   replicas=1      single warm container — all concurrent shard calls land
#                   here so they share one DynamicBatcher process
#   concurrency=8   up to 8 query-shard tasks run simultaneously on the
#                   container, all feeding the same DynamicBatcher queue
#   idle_ttl=120    keep alive 2 min after the last task finishes
#   scaledown_ttl=60 scale to zero after 1 min of complete inactivity
# {{docs-fragment envs}}
colpali_indexer = flyte.TaskEnvironment(
    name="vidore-colpali-indexer",
    image=image,
    resources=flyte.Resources(cpu=4, memory="16Gi", gpu="A10G:1"),
    reusable=flyte.ReusePolicy(
        replicas=1,
        concurrency=8,
        idle_ttl=120,
        scaledown_ttl=60,
    ),
)

# GPU environment for SigLIP image encoding and search.
#
# Separate from the ColPali environment so each model's warm containers
# are managed independently — ColPali and SigLIP experiments can scale
# without contending for the same pool of reusable containers.
siglip_indexer = flyte.TaskEnvironment(
    name="vidore-siglip-indexer",
    image=image,
    resources=flyte.Resources(cpu=4, memory="8Gi", gpu=1),
    reusable=flyte.ReusePolicy(
        replicas=1,
        concurrency=8,
        idle_ttl=120,
        scaledown_ttl=60,
    ),
)

# GPU environment for doctr OCR. doctr runs DBNet (detection) + CRNN (recognition)
# in batches on GPU — much faster than CPU Tesseract.
# No ReusePolicy needed: the result is cached, so this task runs at most once.
ocr_engine = flyte.TaskEnvironment(
    name="vidore-ocr-engine",
    image=image,
    resources=flyte.Resources(cpu=4, memory="20Gi", gpu=1),
)

# Driver: orchestration, BM25 search, evaluation, and reporting.
# depends_on declares that the driver calls tasks in these environments, so
# deploying the driver also deploys them.
driver = flyte.TaskEnvironment(
    name="vidore-driver",
    image=image,
    resources=flyte.Resources(cpu=2, memory="12Gi"),
    depends_on=[colpali_indexer, siglip_indexer, ocr_engine],
)
# {{/docs-fragment envs}}

# ─────────────────────────────────────────────────────────────────────────────
# Configuration types
# ─────────────────────────────────────────────────────────────────────────────

# {{docs-fragment config_types}}
class RetrievalModel(str, enum.Enum):
    """Retrieval backend to evaluate."""

    COLPALI = "colpali-v1.2"  # multi-vector patch embeddings, MaxSim
    SIGLIP = "siglip-so400m"  # single-vector global embedding, cosine sim
    OCR_BM25 = "ocr+bm25"  # text extracted by doctr OCR, ranked by BM25

class ExperimentConfig(BaseModel):
    """
    All knobs for one retrieval experiment. Passed as a typed Flyte input.

    Because ExperimentConfig is a Pydantic model, Flyte serialises it as a
    typed value, and each ExperimentResult carries it next to its metrics, so
    you can always reconstruct which config produced which metric without
    maintaining a separate log.
    """

    name: str  # human-readable label shown in the comparison table
    model: RetrievalModel
    top_k: int = 5  # number of pages to retrieve per query
# {{/docs-fragment config_types}}

# ─────────────────────────────────────────────────────────────────────────────
# Data types
# ─────────────────────────────────────────────────────────────────────────────

# {{docs-fragment data_types}}
class PageQuery(BaseModel):
    """One retrieval query with its ground-truth page."""

    query_id: str
    text: str  # e.g. "What was revenue growth in Q3?"
    relevant_page_id: str  # one correct page per query

class PageDataset(BaseModel):
    """
    A corpus of document page images paired with text queries.

    page_ids:   unique page identifiers of the form "<subset>_0000"; in the
                legacy format, pages are deduplicated by image filename.
    page_files: the same pages stored in Flyte's blob store as JPEG File
                handles. Tasks read images directly from here; no live HTTP.
    queries:    text questions with ground-truth page IDs for evaluation.
    """

    page_ids: list[str]
    page_files: list[File]
    queries: list[PageQuery]

    class Config:
        arbitrary_types_allowed = True

class RetrievalResult(BaseModel):
    query_id: str
    ranked_page_ids: list[str]  # ordered best → worst

class Metrics(BaseModel):
    recall_at_k: float
    ndcg_at_k: float
    mrr: float
    k: int

class ExperimentResult(BaseModel):
    config: ExperimentConfig
    metrics: Metrics
# {{/docs-fragment data_types}}

class ComparisonReport(BaseModel):
    results: list[ExperimentResult]

    def best_by(self, metric: str = "recall_at_k") -> ExperimentResult:
        return max(self.results, key=lambda r: getattr(r.metrics, metric))

    def summary(self) -> str:
        header = f"{'Experiment':<30} {'Model':<18} {'Recall@K':>10} {'NDCG@K':>8} {'MRR':>7}"
        sep = "─" * len(header)
        rows = [header, sep]
        for r in sorted(self.results, key=lambda x: -x.metrics.recall_at_k):
            rows.append(
                f"{r.config.name:<30} "
                f"{r.config.model.value:<18} "
                f"{r.metrics.recall_at_k:>10.3f} "
                f"{r.metrics.ndcg_at_k:>8.3f} "
                f"{r.metrics.mrr:>7.3f}"
            )
        return "\n".join(rows)

# ─────────────────────────────────────────────────────────────────────────────
# Cached model loaders
# ─────────────────────────────────────────────────────────────────────────────
# These functions are at module level so they are shared across all tasks that
# run on the same warm container (via ReusePolicy). lru_cache(maxsize=1) means
# the model is loaded from disk/HuggingFace exactly once per container process
# and kept in GPU memory for every subsequent task dispatch to that container.

@lru_cache(maxsize=1)
def _colpali_model():
    """Load ColPali-v1.2 into GPU memory and cache the result.

    device_map= is the correct loading pattern for ColPali's PaliGemma
    backbone; it handles weight placement via accelerate. torch.compile is
    skipped — ColPali is GPU-compute-bound and the DynamicBatcher's cross-
    invocation batching is the primary GPU utilisation mechanism.
    """
    import torch
    from colpali_engine.models import ColPali, ColPaliProcessor

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = ColPali.from_pretrained(
        "vidore/colpali-v1.2",
        torch_dtype=torch.bfloat16,
        device_map=device,
    )
    processor = ColPaliProcessor.from_pretrained("vidore/colpali-v1.2")
    return model, processor, device

@lru_cache(maxsize=1)
def _siglip_model():
    """Load SigLIP SO400M into GPU memory, compile it, and cache the result.

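    torch.compile (mode="reduce-overhead") wraps the model so that calls to
    model(...) are compiled on first use; that cost is paid once per warm
    container lifetime. Note that the wrapper forwards attribute access to the
    original module, so the model.vision_model(...) and model.text_model(...)
    calls in the tasks below run the uncompiled submodules.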
    """
    import torch
    from transformers import AutoModel, AutoProcessor

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModel.from_pretrained("google/siglip-so400m-patch14-224").to(device)
    if device == "cuda":
        model = torch.compile(model, mode="reduce-overhead")
    processor = AutoProcessor.from_pretrained("google/siglip-so400m-patch14-224")
    return model, processor, device

@lru_cache(maxsize=1)
def _ocr_model():
    """Load the doctr OCR predictor onto GPU and cache it.

    doctr's ocr_predictor bundles a detection model (DBNet) and a
    recognition model (CRNN/SAR) into a single callable. pretrained=True
    downloads both model weights from doctr's model zoo on first use.
    """
    import torch
    from doctr.models import ocr_predictor

    predictor = ocr_predictor(pretrained=True)
    if torch.cuda.is_available():
        predictor = predictor.cuda()
    return predictor

# ─────────────────────────────────────────────────────────────────────────────
# Search batcher singletons
# ─────────────────────────────────────────────────────────────────────────────
# One DynamicBatcher per model, shared across all concurrent search task
# invocations on the same warm container (concurrency=8). Queries from every
# concurrent caller are aggregated into a single GPU batch, maximizing
# throughput compared to each invocation running its own forward pass.
#
# Initialised lazily on the first search call via double-checked locking and
# lives for the container's lifetime. The process_fn runs GPU work via
# asyncio.to_thread so the aggregation loop can continue collecting queries
# from other callers while the GPU processes the current batch.
#
# File is not hashable so alru_cache cannot be used here; module-level state
# with asyncio.Lock is the correct pattern.
#
# Assumption: index_colpali/index_siglip use cache="auto", so the same corpus
# always produces the same index File across all callers on this container. If
# the index file ever changed between calls, the batcher would silently continue
# using the corpus embeddings loaded from the first call.

_colpali_batcher: DynamicBatcher | None = None
_colpali_batcher_lock = asyncio.Lock()
_siglip_batcher: DynamicBatcher | None = None
_siglip_batcher_lock = asyncio.Lock()

async def _get_colpali_search_batcher(index_file: File) -> DynamicBatcher:
    """Return the process-level ColPali search batcher, creating it on first call."""
    global _colpali_batcher
    if _colpali_batcher is not None:
        return _colpali_batcher
    async with _colpali_batcher_lock:
        if _colpali_batcher is not None:
            return _colpali_batcher

        import torch

        data = await _load_npz(index_file)
        corpus_emb = torch.from_numpy(data["embeddings"])  # (n_pages, n_patches, dim)
        index_page_ids: list[str] = list(data["page_ids"])
        model, processor, device = _colpali_model()
        corpus_emb = corpus_emb.to(device, dtype=torch.float32)

        async def colpali_process_fn(batch: list[PageQuery]) -> list[list[str]]:
            def _gpu_work() -> list[list[str]]:
                query_inputs = processor.process_queries([q.text for q in batch])
                query_inputs = {k: v.to(device) for k, v in query_inputs.items()}
                with torch.no_grad():
                    query_embs = model(**query_inputs).float()  # (B, T, D)
                    query_chunk = 8
                    n_pages = corpus_emb.shape[0]
                    all_scores = torch.empty(len(batch), n_pages, device=device)
                    for start in range(0, len(batch), query_chunk):
                        chunk = query_embs[start : start + query_chunk]
                        all_scores[start : start + query_chunk] = (
                            torch.einsum("ctd,pjd->ctpj", chunk, corpus_emb)
                            .max(dim=3).values
                            .sum(dim=1)
                        )
                    sorted_indices = all_scores.argsort(dim=1, descending=True).cpu().tolist()
                return [[index_page_ids[j] for j in ranked] for ranked in sorted_indices]

            # Run GPU work in a thread so the event loop — and the batcher's
            # aggregation loop — remain unblocked while the GPU is busy.
            return await asyncio.to_thread(_gpu_work)

        batcher: DynamicBatcher[PageQuery, list[str]] = DynamicBatcher(
            process_fn=colpali_process_fn,
            target_batch_cost=128,
            max_batch_size=128,
            batch_timeout_s=0.05,
            default_cost=1,
            prefetch_batches=2,
        )
        await batcher.start()
        _colpali_batcher = batcher
    return _colpali_batcher
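
# Illustrative reference only; nothing in the pipeline calls it. This is a
# minimal CPU sketch of the MaxSim late-interaction score that
# _get_colpali_search_batcher computes on GPU, assuming plain NumPy inputs
# for a single query:
#   score(q, p) = sum over query tokens t of max over page patches j of q_t . p_j
# The helper name _maxsim_scores_reference is hypothetical.
def _maxsim_scores_reference(query_emb, corpus_emb) -> "np.ndarray":
    """query_emb: (T, D) token embeddings; corpus_emb: (P, J, D) patch embeddings.

    Returns one MaxSim score per page, shape (P,).
    """
    import numpy as np

    q = np.asarray(query_emb, dtype=np.float32)
    c = np.asarray(corpus_emb, dtype=np.float32)
    # sim[t, p, j] is the dot product of query token t with patch j of page p.
    sim = np.einsum("td,pjd->tpj", q, c)
    # Keep the best-matching patch per (token, page), then sum over tokens.
    return sim.max(axis=2).sum(axis=0)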

async def _get_siglip_search_batcher(index_file: File) -> DynamicBatcher:
    """Return the process-level SigLIP search batcher, creating it on first call."""
    global _siglip_batcher
    if _siglip_batcher is not None:
        return _siglip_batcher
    async with _siglip_batcher_lock:
        if _siglip_batcher is not None:
            return _siglip_batcher

        import torch

        data = await _load_npz(index_file)
        corpus_emb = torch.from_numpy(data["embeddings"])  # (n_pages, dim), L2-normalised
        index_page_ids: list[str] = list(data["page_ids"])
        model, processor, device = _siglip_model()
        corpus_emb = corpus_emb.to(device)

        async def siglip_process_fn(batch: list[PageQuery]) -> list[list[str]]:
            def _gpu_work() -> list[list[str]]:
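                # SigLIP was trained with fixed-length ("max_length") text
                # padding; padding to the longest query in the batch can
                # shift the text embeddings.
                text_inputs = processor(
                    text=[q.text for q in batch],
                    return_tensors="pt",
                    padding="max_length",
                    truncation=True,
                ).to(device)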
                with torch.no_grad():
                    text_out = model.text_model(**text_inputs)
                    query_embs = text_out.pooler_output  # (B, dim)
                    query_embs = query_embs / query_embs.norm(dim=-1, keepdim=True)
                    scores_matrix = corpus_emb @ query_embs.T  # (n_pages, B)
                    sorted_indices = scores_matrix.argsort(dim=0, descending=True).T.cpu().tolist()
                return [[index_page_ids[j] for j in ranked] for ranked in sorted_indices]

            return await asyncio.to_thread(_gpu_work)

        batcher = DynamicBatcher(
            process_fn=siglip_process_fn,
            target_batch_cost=128,
            max_batch_size=128,
            batch_timeout_s=0.05,
            default_cost=1,
            prefetch_batches=2,
        )
        await batcher.start()
        _siglip_batcher = batcher
    return _siglip_batcher

# ─────────────────────────────────────────────────────────────────────────────
# Helpers
# ─────────────────────────────────────────────────────────────────────────────

def _batches(items: list, batch_size: int):
    """Yield successive fixed-size batches from a list."""
    for start in range(0, len(items), batch_size):
        yield items[start : start + batch_size]

def _load_image_sync(f: File) -> PILImage.Image:
    """Blocking download + decode. Intended to be called from a thread pool."""
    with f.open_sync("rb") as fh:
        data = fh.read()
    return PILImage.open(BytesIO(data)).convert("RGB")

async def _load_image(f: File) -> PILImage.Image:
    """Download and decode a page image without blocking the event loop.

    asyncio.to_thread runs _load_image_sync in an OS thread, so other
    coroutines keep running during the blocking network I/O. The indexing
    tasks below overlap I/O with GPU work differently: they submit
    _load_image_sync to the executor for the next batch before running the
    current forward pass.
    """
    return await asyncio.to_thread(_load_image_sync, f)

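async def _load_npz(index_file: File) -> np.lib.npyio.NpzFile:
    """Download an index File to a local temp path and open it with np.load."""
    with tempfile.NamedTemporaryFile(suffix=".npz", delete=False) as tmp:
        async with index_file.open("rb") as fh:
            tmp.write(bytes(await fh.read()))
    # Load only after the with-block has closed (and flushed) the temp file;
    # reading it while it is still open for writing can see a truncated archive.
    return np.load(tmp.name)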

def _dcg(relevances: list[int]) -> float:
    return sum(rel / math.log2(rank + 2) for rank, rel in enumerate(relevances))
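
# Illustrative sketch only; nothing in the pipeline calls it. In the
# single-relevant-page setup used here (one relevant_page_id per PageQuery),
# per-query metrics reduce to closed forms. The helper name is hypothetical,
# and it assumes the evaluation averages these values over queries. With one
# relevant page the ideal DCG is 1 / log2(2) = 1, so nDCG@k equals _dcg of the
# top-k relevance list, i.e. 1 / log2(rank + 2) for a 0-based rank < k.
def _single_query_metrics_sketch(
    ranked_page_ids: list[str], relevant_page_id: str, k: int
) -> tuple[float, float, float]:
    """Return (recall@k, ndcg@k, reciprocal_rank) for one query."""
    import math

    if relevant_page_id not in ranked_page_ids:
        return 0.0, 0.0, 0.0
    rank = ranked_page_ids.index(relevant_page_id)  # 0-based position
    reciprocal_rank = 1.0 / (rank + 1)
    if rank >= k:
        return 0.0, 0.0, reciprocal_rank
    return 1.0, 1.0 / math.log2(rank + 2), reciprocal_rank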

# ─────────────────────────────────────────────────────────────────────────────
# Tasks — data loading
# ─────────────────────────────────────────────────────────────────────────────

@driver.task(cache="auto", retries=3)
async def load_vidore_pages(subset: str = "docvqa", max_pages: int = 200) -> PageDataset:
    """
    Load a ViDoRe benchmark subset and store page images in Flyte's blob store.

    Supports two dataset formats:

    Legacy (subsampled) — single 'test' split with one row per (query, page)
    pair; fields: image, query, image_filename. streaming=True reads only the
    rows requested via islice — no full-shard download.
    Datasets: vidore/docvqa_test_subsampled, vidore/infovqa_test_subsampled

    V3 — separate corpus / queries / qrels splits following the BEIR retrieval
    benchmark format. corpus contains page images; queries contains question
    text; qrels maps query IDs to relevant corpus page IDs (many-to-many).
    Datasets: vidore/vidore_v3_finance_en  (~2 942 pages, 1 854 queries)

    The first call uploads page images to Flyte's blob store and caches the
    PageDataset; every subsequent call with the same arguments returns the
    cached result instantly. retries=3 guards against transient HuggingFace
    network failures.

    Available subsets: "docvqa", "infovqa", "vidore_v3_finance_en"
    """
    from datasets import load_dataset

    subset_map = {
        "docvqa": "vidore/docvqa_test_subsampled",
        "infovqa": "vidore/infovqa_test_subsampled",
        "vidore_v3_finance_en": "vidore/vidore_v3_finance_en",
    }
    dataset_name = subset_map.get(subset, f"vidore/{subset}_test_subsampled")

    # V3 datasets ship with separate corpus / queries / qrels splits.
    _V3_SUBSETS = {"vidore_v3_finance_en"}

    if subset in _V3_SUBSETS:
        # ── V3 format ─────────────────────────────────────────────────────────
        # corpus / queries / qrels are HuggingFace configs (name=), not splits.
        # corpus uses streaming=True so images are decoded one at a time —
        # loading all 2 942 rows eagerly would hold gigabytes of PIL images in
        # the driver's RAM simultaneously. qrels and queries are text-only and
        # small enough to load fully into memory.
        corpus_ds = load_dataset(dataset_name, name="corpus", split="test", streaming=True)
        qrels_ds = load_dataset(dataset_name, name="qrels", split="test")
        queries_ds = load_dataset(dataset_name, name="queries", split="test")

        # Normalise field names — V3 follows BEIR convention (hyphenated ids).
        def _col(ds, *candidates):
            cols = set(ds.column_names)
            for c in candidates:
                if c in cols:
                    return c
            raise KeyError(f"None of {candidates} found in columns {cols}")

        corpus_id_col = _col(corpus_ds, "corpus-id", "corpus_id", "id", "_id")
        query_id_col = _col(queries_ds, "query-id", "query_id", "id", "_id")
        query_text_col = _col(queries_ds, "query", "text")
        qrel_qid_col = _col(qrels_ds, "query-id", "query_id")
        qrel_cid_col = _col(qrels_ds, "corpus-id", "corpus_id")

        # Slice corpus to max_pages, upload each image to Flyte blob store.
        page_ids: list[str] = []
        page_files: list[File] = []
        corpus_id_to_page_id: dict[str, str] = {}

        for i, row in enumerate(islice(corpus_ds, max_pages)):
            img = row.get("image")
            if not isinstance(img, PILImage.Image):
                continue
            cid = str(row[corpus_id_col])
            page_id = f"{subset}_{i:04d}"
            with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f:
                tmp_path = f.name
                img.convert("RGB").save(tmp_path, format="JPEG")
            del img  # free PIL memory before upload
            page_file = await File.from_local(tmp_path)
            os.unlink(tmp_path)
            corpus_id_to_page_id[cid] = page_id
            page_ids.append(page_id)
            page_files.append(page_file)

        # Build query_id → relevant page_id from qrels (first match wins).
        # Only keep relevance judgements whose corpus page is in our slice.
        qrel_map: dict[str, str] = {}
        for row in qrels_ds:
            qid = str(row[qrel_qid_col])
            cid = str(row[qrel_cid_col])
            if cid in corpus_id_to_page_id and qid not in qrel_map:
                qrel_map[qid] = corpus_id_to_page_id[cid]

        # Collect queries that have at least one relevant page in our slice.
        queries: list[PageQuery] = []
        for row in queries_ds:
            qid = str(row[query_id_col])
            if qid not in qrel_map:
                continue
            queries.append(
                PageQuery(
                    query_id=qid,
                    text=str(row[query_text_col]),
                    relevant_page_id=qrel_map[qid],
                )
            )

    else:
        # ── Legacy format ─────────────────────────────────────────────────────
        # Single 'test' split with one row per (query, page) pair.
        ds = load_dataset(dataset_name, split="test", streaming=True)

        page_ids = []
        page_files = []
        queries = []
        seen_pages: dict[str, str] = {}  # image_filename → page_id

        for i, row in enumerate(islice(ds, max_pages)):
            img = row.get("image")
            if not isinstance(img, PILImage.Image):
                continue
            filename: str = row.get("image_filename") or f"page_{i}"
            query_text: str = row.get("query", "")
            if not query_text:
                continue

            # Each unique page is uploaded exactly once; multiple queries may
            # share the same page (same image_filename).
            if filename not in seen_pages:
                page_id = f"{subset}_{len(page_ids):04d}"
                with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f:
                    tmp_path = f.name
                    img.convert("RGB").save(tmp_path, format="JPEG")
                del img  # free PIL memory before upload
                page_file = await File.from_local(tmp_path)
                os.unlink(tmp_path)
                seen_pages[filename] = page_id
                page_ids.append(page_id)
                page_files.append(page_file)
            else:
                page_id = seen_pages[filename]

            queries.append(
                PageQuery(
                    query_id=f"q{i:04d}",
                    text=query_text,
                    relevant_page_id=page_id,
                )
            )

    print(f"Loaded {len(page_ids)} unique pages, {len(queries)} queries", flush=True)
    return PageDataset(page_ids=page_ids, page_files=page_files, queries=queries)

# ─────────────────────────────────────────────────────────────────────────────
# Tasks — indexing
# ─────────────────────────────────────────────────────────────────────────────

@colpali_indexer.task(cache="auto", retries=2)
async def index_colpali(page_ids: list[str], page_files: list[File]) -> File:
    """
    Encode every page with ColPali-v1.2 and save the multi-vector index.

    ColPali skips OCR entirely. It feeds the raw page image into PaliGemma
    (a vision-language model) and produces one embedding vector per image
    patch — roughly 1,024 patches per page, each of dimension 128.

    _colpali_model() is an lru_cache'd loader. On a cold container, it
    downloads and loads the model once. On a warm container (kept alive by
    ReusePolicy), it returns the already-loaded model instantly from cache —
    no repeated ~7 GB download.

    The index is stored as a .npz file in Flyte's blob store:
      embeddings — float32, shape (n_pages, n_patches, dim)
      page_ids   — matching page ID strings

    cache="auto" + retries=2: the result is stored permanently on success;
    transient failures (e.g. HuggingFace rate limits) are retried twice.
    """
    import torch

    model, processor, device = _colpali_model()

    loop = asyncio.get_running_loop()
    batches = list(_batches(page_files, 4))
    n_batches = len(batches)

    # Submit the first batch to the thread pool before entering the loop so
    # that downloads are already in flight when we first await them.
    prefetch = [loop.run_in_executor(None, _load_image_sync, f) for f in batches[0]]

    all_embeddings: list[np.ndarray] = []
    for batch_idx in range(n_batches):
        images = list(await asyncio.gather(*prefetch))

        # Submit next batch downloads immediately — OS threads run these in
        # parallel with the GPU forward pass below.
        if batch_idx + 1 < n_batches:
            prefetch = [loop.run_in_executor(None, _load_image_sync, f) for f in batches[batch_idx + 1]]

        inputs = processor.process_images(images)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            emb = model(**inputs)  # (batch, n_patches, dim)

        all_embeddings.append(emb.cpu().float().numpy())
        print(f"ColPali: indexed batch {batch_idx + 1}/{n_batches}", flush=True)

    embeddings = np.concatenate(all_embeddings, axis=0)  # (n_pages, n_patches, dim)
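    # A fresh temp dir per call keeps concurrent calls on a reused container
    # from overwriting each other's index file.
    out_path = os.path.join(tempfile.mkdtemp(), "colpali_index.npz")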
    np.savez(out_path, embeddings=embeddings, page_ids=np.array(page_ids))
    return await File.from_local(out_path)

@siglip_indexer.task(cache="auto", retries=2)
async def index_siglip(page_ids: list[str], page_files: list[File]) -> File:
    """
    Encode every page with SigLIP SO400M and save the single-vector index.

    SigLIP (2023) is Google's successor to CLIP, trained with sigmoid loss
    instead of softmax — avoiding the normalisation bottleneck that limits
    CLIP's scalability. Produces one global embedding per page.

    _siglip_model() caches the model across warm container reuses.

    The index is stored as a .npz file:
      embeddings — float32, shape (n_pages, dim), L2-normalised
      page_ids   — matching page ID strings
    """
    import torch

    model, processor, device = _siglip_model()

    loop = asyncio.get_running_loop()
    batches = list(_batches(page_files, 8))
    n_batches = len(batches)

    # Submit the first batch to the thread pool before entering the loop so
    # that downloads are already in flight when we first await them.
    prefetch = [loop.run_in_executor(None, _load_image_sync, f) for f in batches[0]]

    all_embeddings: list[np.ndarray] = []
    for batch_idx in range(n_batches):
        images = list(await asyncio.gather(*prefetch))

        # Submit next batch downloads immediately — OS threads run these in
        # parallel with the GPU forward pass below.
        if batch_idx + 1 < n_batches:
            prefetch = [loop.run_in_executor(None, _load_image_sync, f) for f in batches[batch_idx + 1]]

        inputs = processor(images=images, return_tensors="pt", padding=True).to(device)

        with torch.no_grad():
            outputs = model.vision_model(**inputs)
            emb = outputs.pooler_output  # (batch, dim)
            emb = emb / emb.norm(dim=-1, keepdim=True)  # L2 normalise

        all_embeddings.append(emb.cpu().float().numpy())
        print(f"SigLIP: indexed batch {batch_idx + 1}/{n_batches}", flush=True)

    embeddings = np.concatenate(all_embeddings, axis=0)  # (n_pages, dim)
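    # A fresh temp dir per call keeps concurrent calls on a reused container
    # from overwriting each other's index file.
    out_path = os.path.join(tempfile.mkdtemp(), "siglip_index.npz")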
    np.savez(out_path, embeddings=embeddings, page_ids=np.array(page_ids))
    return await File.from_local(out_path)

@ocr_engine.task(cache="auto")
async def extract_page_texts(page_files: list[File]) -> list[str]:
    """
    OCR every page with doctr on GPU to produce a text-only baseline.

    doctr bundles DBNet (detection) + CRNN/SAR (recognition) into a single
    callable predictor. Pages are processed in batches of ocr_batch_size: each
    batch is downloaded concurrently, then passed to the predictor.
    asyncio.to_thread keeps the event loop unblocked during downloads and
    GPU inference.

    Result structure: result.pages[i].blocks[j].lines[k].words[l].value

    Cached: the same corpus is OCR'd at most once across all experiments
    that use the OCR+BM25 backend.
    """
    import gc

    predictor = _ocr_model()

    # Process in batches: download each batch just-in-time so only
    # ocr_batch_size images are in memory at once instead of all 2 000.
    ocr_batch_size = 8
    total = len(page_files)
    texts: list[str] = []
    for start in range(0, total, ocr_batch_size):
        batch_files = page_files[start : start + ocr_batch_size]
        batch_images = list(
            await asyncio.gather(*[asyncio.to_thread(_load_image_sync, f) for f in batch_files])
        )
        batch_np = [np.array(img) for img in batch_images]
        del batch_images
        result = await asyncio.to_thread(predictor, batch_np)
        del batch_np
        for page_output in result.pages:
            texts.append(
                "\n".join(
                    " ".join(word.value for word in line.words)
                    for block in page_output.blocks
                    for line in block.lines
                )
            )
        del result
        gc.collect()
        print(f"OCR: processed {min(start + ocr_batch_size, total)}/{total} pages", flush=True)

    return texts

# ─────────────────────────────────────────────────────────────────────────────
# Tasks — search
# ─────────────────────────────────────────────────────────────────────────────

# {{docs-fragment search_colpali}}
@colpali_indexer.task
async def search_colpali(
    index_file: File,
    queries: list[PageQuery],
    top_k: int,
) -> list[RetrievalResult]:
    """
    Retrieve pages using ColPali MaxSim late interaction via DynamicBatcher.

    MaxSim score for page p given query q:
        score(q, p) = Σ_{t ∈ query tokens} max_{j ∈ page patches} (q_t · p_j)

    Each query is submitted to the process-level DynamicBatcher, which
    aggregates queries from all concurrent search_colpali invocations on the
    same warm container (concurrency=8) into a single GPU batch. This keeps
    the GPU saturated rather than running one small batch per caller.

    The batcher's process_fn runs GPU work in asyncio.to_thread, so the
    aggregation loop stays live while the GPU encodes and scores.
    """
    batcher = await _get_colpali_search_batcher(index_file)
    futures = await batcher.submit_batch(queries)
    all_ranked: list[list[str]] = list(await asyncio.gather(*futures))

    return [
        RetrievalResult(query_id=q.query_id, ranked_page_ids=ranked[:top_k])
        for q, ranked in zip(queries, all_ranked)
    ]
# {{/docs-fragment search_colpali}}

@siglip_indexer.task
async def search_siglip(
    index_file: File,
    queries: list[PageQuery],
    top_k: int,
) -> list[RetrievalResult]:
    """
    Retrieve pages using SigLIP cosine similarity via DynamicBatcher.

    Each query is submitted to the process-level DynamicBatcher, which
    aggregates queries from all concurrent search_siglip invocations on the
    same warm container (concurrency=8) into a single GPU batch.

    SigLIP's single-vector embeddings make full vectorisation safe —
    the scores matrix (n_pages x n_queries) is small enough to materialise
    in one GPU call regardless of batch size.
    """
    batcher = await _get_siglip_search_batcher(index_file)
    futures = await batcher.submit_batch(queries)
    all_ranked: list[list[str]] = list(await asyncio.gather(*futures))

    return [
        RetrievalResult(query_id=q.query_id, ranked_page_ids=ranked[:top_k])
        for q, ranked in zip(queries, all_ranked)
    ]

@driver.task
async def search_bm25(
    page_texts: list[str],
    page_ids: list[str],
    queries: list[PageQuery],
    top_k: int,
) -> list[RetrievalResult]:
    """
    Retrieve pages using BM25 over OCR'd text.

    The standard keyword-based baseline. No GPU required; strong on
    text-dense pages, weak on visual content that OCR (doctr) cannot read.
    """
    tokenized = [text.lower().split() for text in page_texts]
    bm25 = BM25Okapi(tokenized)

    results: list[RetrievalResult] = []
    for q in queries:
        scores = bm25.get_scores(q.text.lower().split())
        ranked = sorted(range(len(page_ids)), key=lambda i: -scores[i])[:top_k]
        results.append(
            RetrievalResult(
                query_id=q.query_id,
                ranked_page_ids=[page_ids[i] for i in ranked],
            )
        )
    return results

# ─────────────────────────────────────────────────────────────────────────────
# Tasks — evaluation
# ─────────────────────────────────────────────────────────────────────────────

@driver.task
async def evaluate(
    results: list[RetrievalResult],
    ground_truth: list[PageQuery],
    k: int,
) -> Metrics:
    """
    Compute Recall@K, NDCG@K, and MRR for a single retrieval model.

    Recall@K  — was the correct page in the top-K results?
    NDCG@K    — normalised discounted cumulative gain; rewards earlier hits.
    MRR       — mean reciprocal rank of the first correct result.

    All three are averaged over all queries; higher is better. Because the
    search tasks already truncate each ranking to top_k, MRR here is
    effectively MRR@K.
    """
    gt_map = {q.query_id: q.relevant_page_id for q in ground_truth}
    recall_vals, ndcg_vals, mrr_vals = [], [], []

    for r in results:
        relevant = gt_map.get(r.query_id, "")
        top = r.ranked_page_ids[:k]

        recall_vals.append(1.0 if relevant in top else 0.0)

        rels = [1 if pid == relevant else 0 for pid in top]
        idcg = _dcg([1])  # ideal: correct page at rank 1
        ndcg_vals.append(_dcg(rels) / idcg if idcg > 0 else 0.0)

        rr = 0.0
        for rank, pid in enumerate(r.ranked_page_ids, start=1):
            if pid == relevant:
                rr = 1.0 / rank
                break
        mrr_vals.append(rr)

    return Metrics(
        recall_at_k=float(np.mean(recall_vals)),
        ndcg_at_k=float(np.mean(ndcg_vals)),
        mrr=float(np.mean(mrr_vals)),
        k=k,
    )

# ─────────────────────────────────────────────────────────────────────────────
# Tasks — report
# ─────────────────────────────────────────────────────────────────────────────

@driver.task(report=True)
async def generate_report(report: ComparisonReport) -> None:
    """
    Emit an interactive HTML report visible in the Flyte UI.

    report=True marks this task as a reporting task. Flyte renders the HTML
    passed to flyte.report.replace.aio() directly in the execution detail
    page, so no separate dashboard or export step is required.

    The report contains:
      - Summary cards: experiment count, best model, best Recall@K.
      - Grouped bar chart: Recall@K, NDCG@K, MRR side-by-side per experiment.
      - Ranked results table with all three metrics.
    """
    sorted_results = sorted(report.results, key=lambda r: -r.metrics.recall_at_k)
    best = sorted_results[0]

    labels = [r.config.name for r in sorted_results]
    recall_vals = [r.metrics.recall_at_k for r in sorted_results]
    ndcg_vals = [r.metrics.ndcg_at_k for r in sorted_results]
    mrr_vals = [r.metrics.mrr for r in sorted_results]

    table_rows = "".join(
        f"""
        <tr>
          <td>{r.config.name}</td>
          <td>{r.config.model.value}</td>
          <td>{r.metrics.recall_at_k:.3f}</td>
          <td>{r.metrics.ndcg_at_k:.3f}</td>
          <td>{r.metrics.mrr:.3f}</td>
          <td>{r.metrics.k}</td>
        </tr>"""
        for r in sorted_results
    )

    html = f"""<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <title>Visual Document Retrieval — Results</title>
  <script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
  <style>
    * {{ box-sizing: border-box; margin: 0; padding: 0; }}
    body {{
      font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;
      background: #f0f2f5; color: #222; padding: 24px;
    }}
    h1 {{ font-size: 1.6em; margin-bottom: 4px; }}
    .subtitle {{ color: #666; margin-bottom: 24px; font-size: 0.95em; }}
    .cards {{
      display: flex; gap: 16px; flex-wrap: wrap; margin-bottom: 28px;
    }}
    .card {{
      background: #fff; border-radius: 10px; padding: 18px 24px;
      box-shadow: 0 1px 4px rgba(0,0,0,.08); min-width: 160px;
    }}
    .card-value {{ font-size: 1.9em; font-weight: 700; color: #4f46e5; }}
    .card-label {{ font-size: 0.8em; color: #888; text-transform: uppercase;
                   letter-spacing: .04em; margin-top: 2px; }}
    .chart-box {{
      background: #fff; border-radius: 10px; padding: 24px;
      box-shadow: 0 1px 4px rgba(0,0,0,.08); margin-bottom: 28px;
    }}
    .chart-box h2 {{ font-size: 1em; margin-bottom: 16px; color: #444; }}
    table {{ width: 100%; border-collapse: collapse; font-size: 0.9em; }}
    th {{
      background: #4f46e5; color: #fff; padding: 10px 14px;
      text-align: left; font-weight: 600;
    }}
    td {{ padding: 9px 14px; border-bottom: 1px solid #eee; }}
    tr:hover td {{ background: #f8f8ff; }}
    tr:first-child td {{ font-weight: 600; }}
  </style>
</head>
<body>
  <h1>Visual Document Retrieval — Experiment Comparison</h1>
  <p class="subtitle">ViDoRe benchmark &middot; {len(report.results)} experiment(s)</p>

  <div class="cards">
    <div class="card">
      <div class="card-value">{len(report.results)}</div>
      <div class="card-label">Experiments</div>
    </div>
    <div class="card">
      <div class="card-value">{best.config.name}</div>
      <div class="card-label">Best by Recall@K</div>
    </div>
    <div class="card">
      <div class="card-value">{best.metrics.recall_at_k:.3f}</div>
      <div class="card-label">Best Recall@{best.metrics.k}</div>
    </div>
    <div class="card">
      <div class="card-value">{best.metrics.ndcg_at_k:.3f}</div>
      <div class="card-label">Best NDCG@{best.metrics.k}</div>
    </div>
    <div class="card">
      <div class="card-value">{best.metrics.mrr:.3f}</div>
      <div class="card-label">Best MRR</div>
    </div>
  </div>

  <div class="chart-box">
    <h2>Metrics by Experiment</h2>
    <canvas id="metricsChart" height="100"></canvas>
  </div>

  <div class="chart-box">
    <h2>Ranked Results</h2>
    <table>
      <thead>
        <tr>
          <th>Experiment</th><th>Model</th>
          <th>Recall@K</th><th>NDCG@K</th><th>MRR</th><th>K</th>
        </tr>
      </thead>
      <tbody>{table_rows}</tbody>
    </table>
  </div>

  <script>
    new Chart(document.getElementById('metricsChart'), {{
      type: 'bar',
      data: {{
        labels: {json.dumps(labels)},
        datasets: [
          {{
            label: 'Recall@K',
            data: {json.dumps(recall_vals)},
            backgroundColor: 'rgba(79,70,229,0.75)',
            borderRadius: 4
          }},
          {{
            label: 'NDCG@K',
            data: {json.dumps(ndcg_vals)},
            backgroundColor: 'rgba(16,185,129,0.75)',
            borderRadius: 4
          }},
          {{
            label: 'MRR',
            data: {json.dumps(mrr_vals)},
            backgroundColor: 'rgba(245,158,11,0.75)',
            borderRadius: 4
          }}
        ]
      }},
      options: {{
        responsive: true,
        plugins: {{ legend: {{ position: 'top' }} }},
        scales: {{
          y: {{ beginAtZero: true, max: 1.0,
               title: {{ display: true, text: 'Score' }} }}
        }}
      }}
    }});
  </script>
</body>
</html>"""

    await flyte.report.replace.aio(html)
    await flyte.report.flush.aio()

# ─────────────────────────────────────────────────────────────────────────────
# Experiment orchestration
# ─────────────────────────────────────────────────────────────────────────────

# {{docs-fragment run_experiment}}
@driver.task
async def run_experiment(config: ExperimentConfig, dataset: PageDataset) -> ExperimentResult:
    """
    End-to-end retrieval pipeline for a single ExperimentConfig.

    Flyte v2's dynamic execution means this driver task can call GPU tasks
    (index_colpali, search_colpali) based on the runtime value of config.model
    — no static DAG wiring required. The if/elif is plain Python; Flyte
    schedules the selected sub-tasks on the appropriate environment.

    Caching: two experiments that share the same model and corpus (e.g. ColPali
    at top_k=5 and top_k=10) will hit the same cached index. GPU work is paid
    at most once per (model, corpus) pair across all experiments.

    Search queries are sharded into chunks of SEARCH_SHARD_SIZE and dispatched
    as concurrent task invocations. All shards land on the single warm container
    (replicas=1) and feed the same DynamicBatcher simultaneously, keeping the
    GPU saturated throughout search rather than processing one large sequential
    batch from a single caller.

    flyte.group wraps each experiment in a named span in the Flyte UI, making
    it easy to compare latencies and drill into individual runs.
    """
    SEARCH_SHARD_SIZE = 256

    with flyte.group(config.name):
        if config.model == RetrievalModel.COLPALI:
            index_file = await index_colpali(dataset.page_ids, dataset.page_files)
            shards = list(_batches(dataset.queries, SEARCH_SHARD_SIZE))
            shard_results = await asyncio.gather(
                *[search_colpali(index_file, shard, config.top_k) for shard in shards]
            )
            results = [r for shard in shard_results for r in shard]

        elif config.model == RetrievalModel.SIGLIP:
            index_file = await index_siglip(dataset.page_ids, dataset.page_files)
            shards = list(_batches(dataset.queries, SEARCH_SHARD_SIZE))
            shard_results = await asyncio.gather(
                *[search_siglip(index_file, shard, config.top_k) for shard in shards]
            )
            results = [r for shard in shard_results for r in shard]

        else:  # RetrievalModel.OCR_BM25
            page_texts = await extract_page_texts(dataset.page_files)
            results = await search_bm25(page_texts, dataset.page_ids, dataset.queries, config.top_k)

        metrics = await evaluate(results, dataset.queries, config.top_k)

    return ExperimentResult(config=config, metrics=metrics)
# {{/docs-fragment run_experiment}}

# {{docs-fragment compare_experiments}}
@driver.task
async def compare_experiments(
    configs: list[ExperimentConfig],
    subset: str = "docvqa",
    max_pages: int = 200,
) -> ComparisonReport:
    """
    Fan out over all experiment configs and return a ranked comparison table.

    The dataset is loaded once and shared across all experiments. Each config
    runs as a concurrent Flyte task via asyncio.gather. Experiments that share
    a model reuse the cached index — you only pay GPU time for new work.

    On completion, generate_report emits an interactive Chart.js HTML report
    visible directly in the Flyte execution detail page.

    The __main__ entry point overrides these defaults (subset="docvqa",
    max_pages=200) and runs on vidore_v3_finance_en (~2 942 corpus pages,
    1 854 queries) with max_pages=2000 to exercise the GPU pipeline at scale.
    """
    dataset = await load_vidore_pages(subset=subset, max_pages=max_pages)

    # All experiments launch concurrently. Shared cached outputs (same model,
    # same corpus) are served from cache rather than recomputed.
    experiment_coros = [run_experiment(config=cfg, dataset=dataset) for cfg in configs]
    results: list[ExperimentResult] = list(await asyncio.gather(*experiment_coros))

    report = ComparisonReport(results=results)
    print(report.summary())
    best = report.best_by("recall_at_k")
    print(f"\nBest by Recall@{best.metrics.k}: {best.config.name}")

    # Emit the interactive HTML report in the Flyte UI.
    await generate_report(report)

    return report
# {{/docs-fragment compare_experiments}}

# ─────────────────────────────────────────────────────────────────────────────
# Entry point
# ─────────────────────────────────────────────────────────────────────────────

if __name__ == "__main__":
    flyte.init_from_config()

    # Define the experiment grid. Each ExperimentConfig is one point in the
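    # design space. Adding another experiment for an existing backend, or
    # varying top_k, is one line here with no task code changes. A new
    # backend additionally needs its own tasks and a branch in run_experiment.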
    #
    # ColPali appears twice with different top_k values. The cache ensures
    # index_colpali runs only once and both experiments share that result.
    # {{docs-fragment grid}}
    configs = [
        ExperimentConfig(name="colpali-top5", model=RetrievalModel.COLPALI, top_k=5),
        ExperimentConfig(name="colpali-top10", model=RetrievalModel.COLPALI, top_k=10),
        ExperimentConfig(name="siglip-top5", model=RetrievalModel.SIGLIP, top_k=5),
        ExperimentConfig(name="ocr-bm25-top5", model=RetrievalModel.OCR_BM25, top_k=5),
    ]
    # {{/docs-fragment grid}}

    run = flyte.with_runcontext().run(
        compare_experiments,
        configs=configs,
        subset="vidore_v3_finance_en",
        max_pages=2000,
    )
    print(f"Run URL: {run.url}")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/multimodal-retrieval-evaluation/retrieval_eval.py*
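
The scoring and evaluation logic in the tasks above can be traced by hand on toy data. The following is a minimal, dependency-free sketch, not the pipeline's code: it uses plain Python lists instead of torch tensors, and the helper names (`maxsim`, `rank_pages`, `metrics_for_query`) are hypothetical. It follows the MaxSim formula from the `search_colpali` docstring and the Recall@K, NDCG@K, and reciprocal-rank definitions used by `evaluate`.

```python
import math


def dot(a: list[float], b: list[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def maxsim(query_tokens: list[list[float]], page_patches: list[list[float]]) -> float:
    """Late interaction: best-matching patch per query token, summed over tokens."""
    return sum(max(dot(q, p) for p in page_patches) for q in query_tokens)


def rank_pages(query_tokens: list[list[float]], corpus: dict[str, list[list[float]]]) -> list[str]:
    """Return page IDs ordered best to worst by MaxSim score."""
    return sorted(corpus, key=lambda pid: -maxsim(query_tokens, corpus[pid]))


def dcg(relevances: list[int]) -> float:
    return sum(rel / math.log2(rank + 2) for rank, rel in enumerate(relevances))


def metrics_for_query(ranked: list[str], relevant: str, k: int) -> tuple[float, float, float]:
    """Recall@K, NDCG@K and reciprocal rank for a query with one relevant page."""
    top = ranked[:k]
    recall = 1.0 if relevant in top else 0.0
    ndcg = dcg([1 if pid == relevant else 0 for pid in top]) / dcg([1])
    # Reciprocal rank is read from the full list passed in, not just the top k.
    rr = next((1.0 / r for r, pid in enumerate(ranked, start=1) if pid == relevant), 0.0)
    return recall, ndcg, rr


# Toy corpus: each page is a list of 2-d patch vectors.
corpus = {
    "page-a": [[1.0, 0.0], [0.0, 0.1]],  # mostly dimension 0
    "page-b": [[0.0, 1.0], [0.2, 0.0]],  # one patch strongly on dimension 1
    "page-c": [[0.5, 0.5]],              # a single mixed patch
}
query = [[0.0, 1.0], [0.1, 0.9]]  # two query tokens, both pointing at dimension 1

ranked = rank_pages(query, corpus)
print(ranked)                                      # ['page-b', 'page-c', 'page-a']
print(metrics_for_query(ranked, "page-b", k=2))    # (1.0, 1.0, 1.0)
print(metrics_for_query(ranked, "page-c", k=1))    # (0.0, 0.0, 0.5)
```

The second query shows Recall@1 and NDCG@1 at zero while the reciprocal rank is 0.5, because the rank is read from the full list. In the pipeline the search tasks truncate each ranking to `top_k` before `evaluate` runs, so a hit outside the top K contributes 0 to MRR as well.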

## Configuration and data types

An experiment is fully described by an `ExperimentConfig`. Because it's a Pydantic model, Flyte serializes it alongside every output.
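
To illustrate what that buys you, here is a stdlib-only sketch that uses a frozen `dataclass` as a stand-in for the Pydantic model (a substitution made so the example has no dependencies). A config round-trips through JSON unchanged, and because the index and OCR tasks take the corpus but not `top_k` as input, the four-experiment grid from the entry point needs only three distinct index or OCR jobs. `to_json`, `from_json`, and `index_key` are hypothetical helpers, not part of the pipeline or of Flyte's API.

```python
import enum
import json
from dataclasses import asdict, dataclass


class RetrievalModel(str, enum.Enum):
    COLPALI = "colpali-v1.2"
    SIGLIP = "siglip-so400m"
    OCR_BM25 = "ocr+bm25"


@dataclass(frozen=True)
class ExperimentConfig:
    # Stand-in for the Pydantic ExperimentConfig: same fields, same default.
    name: str
    model: RetrievalModel
    top_k: int = 5


def to_json(cfg: ExperimentConfig) -> str:
    data = asdict(cfg)
    data["model"] = cfg.model.value  # store the enum by its string value
    return json.dumps(data)


def from_json(payload: str) -> ExperimentConfig:
    data = json.loads(payload)
    return ExperimentConfig(
        name=data["name"], model=RetrievalModel(data["model"]), top_k=data["top_k"]
    )


def index_key(cfg: ExperimentConfig, corpus_id: str) -> tuple[str, str]:
    # Hypothetical cache key for the index/OCR step: backend + corpus.
    # top_k only matters at search time, so it is deliberately left out.
    return (cfg.model.value, corpus_id)


configs = [
    ExperimentConfig(name="colpali-top5", model=RetrievalModel.COLPALI, top_k=5),
    ExperimentConfig(name="colpali-top10", model=RetrievalModel.COLPALI, top_k=10),
    ExperimentConfig(name="siglip-top5", model=RetrievalModel.SIGLIP, top_k=5),
    ExperimentConfig(name="ocr-bm25-top5", model=RetrievalModel.OCR_BM25, top_k=5),
]

restored = [from_json(to_json(cfg)) for cfg in configs]
jobs = {index_key(cfg, "vidore_v3_finance_en") for cfg in configs}
print(f"{len(configs)} experiments, {len(jobs)} distinct index/OCR jobs")
```

Keeping `top_k` out of the index step's inputs is what lets `colpali-top5` and `colpali-top10` share one cached ColPali index.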

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "colpali-engine>=0.3.1",
#     "transformers>=4.41",
#     "sentencepiece>=0.2",
#     "torch>=2.0",
#     "pillow>=10",
#     "datasets>=2.18",
#     "rank-bm25>=0.2",
#     "numpy>=1.26",
#     "python-doctr[torch]>=0.8",
#     "pydantic>=2.0",
#     "flyte>=2.0.0",
# ]
# ///
"""
Multimodal Retrieval Evaluation Pipeline

This tutorial is an experiment framework for benchmarking visual document
retrieval approaches on the ViDoRe benchmark. Each experiment is defined by
an ExperimentConfig; the pipeline fans them out as concurrent Flyte tasks and
returns a ranked comparison table with an interactive HTML report.

The corpus is a set of PDF page images; queries are plain-text questions. Each
retrieval method must find the page that answers each question — no text is
provided to the model, only the raw image.

  ColPali-v1.2  — patch-level multi-vector embeddings from a VLM (PaliGemma).
                  No OCR. The model produces one vector per image patch
                  (~1024 per page). MaxSim late-interaction scoring finds the
                  best matching patch for each query token.

  SigLIP-SO400M — single global embedding per page from Google's 2023 CLIP
                  successor. One matrix multiply per query; fast and effective
                  but a single vector cannot localise fine-grained regions.

  OCR + BM25    — text-only baseline. doctr (GPU OCR) extracts text in
                  batches, BM25 matches keywords. Strong on text-dense pages;
                  fails on charts, tables, and figures where content is visual.

"""

import asyncio
import enum
import json
import math
import os
import tempfile
from functools import lru_cache
from io import BytesIO
from itertools import islice

import numpy as np
from PIL import Image as PILImage
from pydantic import BaseModel
from rank_bm25 import BM25Okapi

from extras import DynamicBatcher

import flyte
import flyte.report
from flyte.io import File

# ─────────────────────────────────────────────────────────────────────────────
# Environments
# ─────────────────────────────────────────────────────────────────────────────

# One Docker image for all tasks. The PEP 723 header defines Python deps.
# ca-certificates is required for HTTPS calls to HuggingFace and blob stores.
# {{docs-fragment image}}
image = (
    flyte.Image.from_uv_script(__file__, name="vidore-eval-v2")
    .with_apt_packages("ca-certificates", "libxcb1", "libgl1", "libglib2.0-0")
    # unionai-reuse installs the unionai-actor-bridge binary required by ReusePolicy.
    # Without it every reusable container exits with StartError (exit code 128).
    .with_pip_packages("unionai-reuse>=0.1.11")
)
# {{/docs-fragment image}}

# GPU environment for ColPali image encoding and search.
#
# ReusePolicy keeps a single warm GPU container (replicas=1) alive between task calls.
# Without it, every task invocation cold-starts a new container and downloads
# ColPali-v1.2 (~7 GB) from scratch. With it, the container — and the model
# weights already loaded into VRAM — is reused for the next task dispatch.
#
#   replicas=1      single warm container — all concurrent shard calls land
#                   here so they share one DynamicBatcher process
#   concurrency=8   up to 8 query-shard tasks run simultaneously on the
#                   container, all feeding the same DynamicBatcher queue
#   idle_ttl=120    keep alive 2 min after the last task finishes
#   scaledown_ttl=60 scale to zero after 1 min of complete inactivity
# {{docs-fragment envs}}
colpali_indexer = flyte.TaskEnvironment(
    name="vidore-colpali-indexer",
    image=image,
    resources=flyte.Resources(cpu=4, memory="16Gi", gpu="A10G:1"),
    reusable=flyte.ReusePolicy(
        replicas=1,
        concurrency=8,
        idle_ttl=120,
        scaledown_ttl=60,
    ),
)

# GPU environment for SigLIP image encoding and search.
#
# Separate from the ColPali environment so each model's warm containers
# are managed independently — ColPali and SigLIP experiments can scale
# without contending for the same pool of reusable containers.
siglip_indexer = flyte.TaskEnvironment(
    name="vidore-siglip-indexer",
    image=image,
    resources=flyte.Resources(cpu=4, memory="8Gi", gpu=1),
    reusable=flyte.ReusePolicy(
        replicas=1,
        concurrency=8,
        idle_ttl=120,
        scaledown_ttl=60,
    ),
)

# GPU environment for doctr OCR. doctr runs DBNet (detection) + CRNN (recognition)
# in batches on GPU — much faster than CPU Tesseract.
# No ReusePolicy needed: the result is cached, so this task runs at most once.
ocr_engine = flyte.TaskEnvironment(
    name="vidore-ocr-engine",
    image=image,
    resources=flyte.Resources(cpu=4, memory="20Gi", gpu=1),
)

# Driver: orchestration, BM25 search, evaluation, and reporting.
# depends_on declares the GPU environments as dependencies of the driver, so
# deploying the driver also deploys the environments its tasks call into.
driver = flyte.TaskEnvironment(
    name="vidore-driver",
    image=image,
    resources=flyte.Resources(cpu=2, memory="12Gi"),
    depends_on=[colpali_indexer, siglip_indexer, ocr_engine],
)
# {{/docs-fragment envs}}

# ─────────────────────────────────────────────────────────────────────────────
# Configuration types
# ─────────────────────────────────────────────────────────────────────────────

# {{docs-fragment config_types}}
class RetrievalModel(str, enum.Enum):
    """Retrieval backend to evaluate."""

    COLPALI = "colpali-v1.2"  # multi-vector patch embeddings, MaxSim
    SIGLIP = "siglip-so400m"  # single-vector global embedding, cosine sim
    OCR_BM25 = "ocr+bm25"  # text extracted by doctr OCR, ranked by BM25

class ExperimentConfig(BaseModel):
    """
    All knobs for one retrieval experiment. Passed as a typed Flyte input.

    Because ExperimentConfig is a Pydantic model, Flyte serialises it
    alongside every task output — so you can always reconstruct which
    config produced which metric without maintaining a separate log.
    """

    name: str  # human-readable label shown in the comparison table
    model: RetrievalModel
    top_k: int = 5  # number of pages to retrieve per query
# {{/docs-fragment config_types}}

# ─────────────────────────────────────────────────────────────────────────────
# Data types
# ─────────────────────────────────────────────────────────────────────────────

# {{docs-fragment data_types}}
class PageQuery(BaseModel):
    """One retrieval query with its ground-truth page."""

    query_id: str
    text: str  # e.g. "What was revenue growth in Q3?"
    relevant_page_id: str  # one correct page per query

class PageDataset(BaseModel):
    """
    A corpus of document page images paired with text queries.

    page_ids:   unique page identifiers (derived from ViDoRe image filenames).
    page_files: the same pages stored in Flyte's blob store as JPEG File
                handles. Tasks read images directly from here; no live HTTP.
    queries:    text questions with ground-truth page IDs for evaluation.
    """

    page_ids: list[str]
    page_files: list[File]
    queries: list[PageQuery]

    model_config = {"arbitrary_types_allowed": True}

class RetrievalResult(BaseModel):
    query_id: str
    ranked_page_ids: list[str]  # ordered best → worst

class Metrics(BaseModel):
    recall_at_k: float
    ndcg_at_k: float
    mrr: float
    k: int

class ExperimentResult(BaseModel):
    config: ExperimentConfig
    metrics: Metrics
# {{/docs-fragment data_types}}

class ComparisonReport(BaseModel):
    results: list[ExperimentResult]

    def best_by(self, metric: str = "recall_at_k") -> ExperimentResult:
        return max(self.results, key=lambda r: getattr(r.metrics, metric))

    def summary(self) -> str:
        header = f"{'Experiment':<30} {'Model':<18} {'Recall@K':>10} {'NDCG@K':>8} {'MRR':>7}"
        sep = "─" * len(header)
        rows = [header, sep]
        for r in sorted(self.results, key=lambda x: -x.metrics.recall_at_k):
            rows.append(
                f"{r.config.name:<30} "
                f"{r.config.model.value:<18} "
                f"{r.metrics.recall_at_k:>10.3f} "
                f"{r.metrics.ndcg_at_k:>8.3f} "
                f"{r.metrics.mrr:>7.3f}"
            )
        return "\n".join(rows)

# ─────────────────────────────────────────────────────────────────────────────
# Cached model loaders
# ─────────────────────────────────────────────────────────────────────────────
# These functions are at module level so they are shared across all tasks that
# run on the same warm container (via ReusePolicy). lru_cache(maxsize=1) means
# the model is loaded from disk/HuggingFace exactly once per container process
# and kept in GPU memory for every subsequent task dispatch to that container.

@lru_cache(maxsize=1)
def _colpali_model():
    """Load ColPali-v1.2 into GPU memory and cache the result.

    device_map= is the correct loading pattern for ColPali's PaliGemma
    backbone; it handles weight placement via accelerate. torch.compile is
    skipped — ColPali is GPU-compute-bound and the DynamicBatcher's cross-
    invocation batching is the primary GPU utilisation mechanism.
    """
    import torch
    from colpali_engine.models import ColPali, ColPaliProcessor

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = ColPali.from_pretrained(
        "vidore/colpali-v1.2",
        torch_dtype=torch.bfloat16,
        device_map=device,
    )
    processor = ColPaliProcessor.from_pretrained("vidore/colpali-v1.2")
    return model, processor, device

@lru_cache(maxsize=1)
def _siglip_model():
    """Load SigLIP SO400M into GPU memory, compile it, and cache the result.

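    torch.compile (mode="reduce-overhead") compiles the forward pass into
    optimised kernels and uses CUDA graphs to cut kernel-launch overhead.
    Compilation happens lazily on the first call and is paid once per warm
    container lifetime. Attribute access on the compiled wrapper, such as
    model.text_model(...), reaches the original uncompiled submodule, so
    only calls made through model(...) itself run the compiled graph.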
    """
    import torch
    from transformers import AutoModel, AutoProcessor

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModel.from_pretrained("google/siglip-so400m-patch14-224").to(device)
    if device == "cuda":
        model = torch.compile(model, mode="reduce-overhead")
    processor = AutoProcessor.from_pretrained("google/siglip-so400m-patch14-224")
    return model, processor, device

@lru_cache(maxsize=1)
def _ocr_model():
    """Load the doctr OCR predictor onto GPU and cache it.

    doctr's ocr_predictor bundles a detection model (DBNet) and a
    recognition model (CRNN/SAR) into a single callable. pretrained=True
    downloads both model weights from doctr's model zoo on first use.
    """
    import torch
    from doctr.models import ocr_predictor

    predictor = ocr_predictor(pretrained=True)
    if torch.cuda.is_available():
        predictor = predictor.cuda()
    return predictor

# ─────────────────────────────────────────────────────────────────────────────
# Search batcher singletons
# ─────────────────────────────────────────────────────────────────────────────
# One DynamicBatcher per model, shared across all concurrent search task
# invocations on the same warm container (concurrency=8). Queries from every
# concurrent caller are aggregated into a single GPU batch, maximizing
# throughput compared to each invocation running its own forward pass.
#
# Initialised lazily on the first search call via double-checked locking and
# lives for the container's lifetime. The process_fn runs GPU work via
# asyncio.to_thread so the aggregation loop can continue collecting queries
# from other callers while the GPU processes the current batch.
#
# File is not hashable so alru_cache cannot be used here; module-level state
# with asyncio.Lock is the correct pattern.
#
# Assumption: index_colpali/index_siglip use cache="auto", so the same corpus
# always produces the same index File across all callers on this container. If
# the index file ever changed between calls, the batcher would silently continue
# using the corpus embeddings loaded from the first call.

_colpali_batcher: DynamicBatcher | None = None
_colpali_batcher_lock = asyncio.Lock()
_siglip_batcher: DynamicBatcher | None = None
_siglip_batcher_lock = asyncio.Lock()

async def _get_colpali_search_batcher(index_file: File) -> DynamicBatcher:
    """Return the process-level ColPali search batcher, creating it on first call."""
    global _colpali_batcher
    if _colpali_batcher is not None:
        return _colpali_batcher
    async with _colpali_batcher_lock:
        if _colpali_batcher is not None:
            return _colpali_batcher

        import torch

        data = await _load_npz(index_file)
        corpus_emb = torch.from_numpy(data["embeddings"])  # (n_pages, n_patches, dim)
        index_page_ids: list[str] = list(data["page_ids"])
        model, processor, device = _colpali_model()
        corpus_emb = corpus_emb.to(device, dtype=torch.float32)

        async def colpali_process_fn(batch: list[PageQuery]) -> list[list[str]]:
            def _gpu_work() -> list[list[str]]:
                query_inputs = processor.process_queries([q.text for q in batch])
                query_inputs = {k: v.to(device) for k, v in query_inputs.items()}
                with torch.no_grad():
                    query_embs = model(**query_inputs).float()  # (B, T, D)
                    query_chunk = 8
                    n_pages = corpus_emb.shape[0]
                    all_scores = torch.empty(len(batch), n_pages, device=device)
                    for start in range(0, len(batch), query_chunk):
                        chunk = query_embs[start : start + query_chunk]
                        all_scores[start : start + query_chunk] = (
                            torch.einsum("ctd,pjd->ctpj", chunk, corpus_emb)
                            .max(dim=3).values
                            .sum(dim=1)
                        )
                    sorted_indices = all_scores.argsort(dim=1, descending=True).cpu().tolist()
                return [[index_page_ids[j] for j in ranked] for ranked in sorted_indices]

            # Run GPU work in a thread so the event loop — and the batcher's
            # aggregation loop — remain unblocked while the GPU is busy.
            return await asyncio.to_thread(_gpu_work)

        batcher: DynamicBatcher[PageQuery, list[str]] = DynamicBatcher(
            process_fn=colpali_process_fn,
            target_batch_cost=128,
            max_batch_size=128,
            batch_timeout_s=0.05,
            default_cost=1,
            prefetch_batches=2,
        )
        await batcher.start()
        _colpali_batcher = batcher
    return _colpali_batcher

async def _get_siglip_search_batcher(index_file: File) -> DynamicBatcher:
    """Return the process-level SigLIP search batcher, creating it on first call."""
    global _siglip_batcher
    if _siglip_batcher is not None:
        return _siglip_batcher
    async with _siglip_batcher_lock:
        if _siglip_batcher is not None:
            return _siglip_batcher

        import torch

        data = await _load_npz(index_file)
        corpus_emb = torch.from_numpy(data["embeddings"])  # (n_pages, dim), L2-normalised
        index_page_ids: list[str] = list(data["page_ids"])
        model, processor, device = _siglip_model()
        corpus_emb = corpus_emb.to(device)

        async def siglip_process_fn(batch: list[PageQuery]) -> list[list[str]]:
            def _gpu_work() -> list[list[str]]:
                text_inputs = processor(
                    text=[q.text for q in batch],
                    return_tensors="pt",
                    padding="max_length",  # SigLIP's text tower was trained on max_length-padded input
                    truncation=True,
                ).to(device)
                with torch.no_grad():
                    text_out = model.text_model(**text_inputs)
                    query_embs = text_out.pooler_output  # (B, dim)
                    query_embs = query_embs / query_embs.norm(dim=-1, keepdim=True)
                    scores_matrix = corpus_emb @ query_embs.T  # (n_pages, B)
                    sorted_indices = scores_matrix.argsort(dim=0, descending=True).T.cpu().tolist()
                return [[index_page_ids[j] for j in ranked] for ranked in sorted_indices]

            return await asyncio.to_thread(_gpu_work)

        batcher = DynamicBatcher(
            process_fn=siglip_process_fn,
            target_batch_cost=128,
            max_batch_size=128,
            batch_timeout_s=0.05,
            default_cost=1,
            prefetch_batches=2,
        )
        await batcher.start()
        _siglip_batcher = batcher
    return _siglip_batcher

# ─────────────────────────────────────────────────────────────────────────────
# Helpers
# ─────────────────────────────────────────────────────────────────────────────

def _batches(items: list, batch_size: int):
    """Yield successive fixed-size batches from a list."""
    for start in range(0, len(items), batch_size):
        yield items[start : start + batch_size]

def _load_image_sync(f: File) -> PILImage.Image:
    """Blocking download + decode. Intended to be called from a thread pool."""
    with f.open_sync("rb") as fh:
        data = fh.read()
    return PILImage.open(BytesIO(data)).convert("RGB")

async def _load_image(f: File) -> PILImage.Image:
    """Download and decode a page image in a thread-pool worker.

    asyncio.to_thread runs _load_image_sync in a separate OS thread, so the
    blocking download does not stall the event loop. The indexing tasks use
    the same idea via loop.run_in_executor, submitting the next batch's
    downloads before the GPU forward pass so network I/O and compute overlap.
    """
    return await asyncio.to_thread(_load_image_sync, f)

async def _load_npz(index_file: File) -> np.lib.npyio.NpzFile:
    """Download an index File to a local temp path and open with np.load."""
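    with tempfile.NamedTemporaryFile(suffix=".npz", delete=False) as tmp:
        async with index_file.open("rb") as fh:
            tmp.write(bytes(await fh.read()))
    # Open only after the with-block has closed (and flushed) the temp file;
    # calling np.load inside the block can read a partially written file.
    return np.load(tmp.name)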

def _dcg(relevances: list[int]) -> float:
    return sum(rel / math.log2(rank + 2) for rank, rel in enumerate(relevances))

# ─────────────────────────────────────────────────────────────────────────────
# Tasks — data loading
# ─────────────────────────────────────────────────────────────────────────────

@driver.task(cache="auto", retries=3)
async def load_vidore_pages(subset: str = "docvqa", max_pages: int = 200) -> PageDataset:
    """
    Load a ViDoRe benchmark subset and store page images in Flyte's blob store.

    Supports two dataset formats:

    Legacy (subsampled) — single 'test' split with one row per (query, page)
    pair; fields: image, query, image_filename. streaming=True reads only the
    rows requested via islice — no full-shard download.
    Datasets: vidore/docvqa_test_subsampled, vidore/infovqa_test_subsampled

    V3 — separate corpus / queries / qrels splits following the BEIR retrieval
    benchmark format. corpus contains page images; queries contains question
    text; qrels maps query IDs to relevant corpus page IDs (many-to-many).
    Datasets: vidore/vidore_v3_finance_en  (~2 942 pages, 1 854 queries)

    The first call uploads page images to Flyte's blob store and caches the
    PageDataset; every subsequent call with the same arguments returns the
    cached result instantly. retries=3 guards against transient HuggingFace
    network failures.

    Available subsets: "docvqa", "infovqa", "vidore_v3_finance_en"
    """
    from datasets import load_dataset

    subset_map = {
        "docvqa": "vidore/docvqa_test_subsampled",
        "infovqa": "vidore/infovqa_test_subsampled",
        "vidore_v3_finance_en": "vidore/vidore_v3_finance_en",
    }
    dataset_name = subset_map.get(subset, f"vidore/{subset}_test_subsampled")

    # V3 datasets ship with separate corpus / queries / qrels splits.
    _V3_SUBSETS = {"vidore_v3_finance_en"}

    if subset in _V3_SUBSETS:
        # ── V3 format ─────────────────────────────────────────────────────────
        # corpus / queries / qrels are HuggingFace configs (name=), not splits.
        # corpus uses streaming=True so images are decoded one at a time —
        # loading all 2 942 rows eagerly would hold gigabytes of PIL images in
        # the driver's RAM simultaneously. qrels and queries are text-only and
        # small enough to load fully into memory.
        corpus_ds = load_dataset(dataset_name, name="corpus", split="test", streaming=True)
        qrels_ds = load_dataset(dataset_name, name="qrels", split="test")
        queries_ds = load_dataset(dataset_name, name="queries", split="test")

        # Normalise field names — V3 follows BEIR convention (hyphenated ids).
        def _col(ds, *candidates):
            cols = set(ds.column_names)
            for c in candidates:
                if c in cols:
                    return c
            raise KeyError(f"None of {candidates} found in columns {cols}")

        corpus_id_col = _col(corpus_ds, "corpus-id", "corpus_id", "id", "_id")
        query_id_col = _col(queries_ds, "query-id", "query_id", "id", "_id")
        query_text_col = _col(queries_ds, "query", "text")
        qrel_qid_col = _col(qrels_ds, "query-id", "query_id")
        qrel_cid_col = _col(qrels_ds, "corpus-id", "corpus_id")

        # Slice corpus to max_pages, upload each image to Flyte blob store.
        page_ids: list[str] = []
        page_files: list[File] = []
        corpus_id_to_page_id: dict[str, str] = {}

        for i, row in enumerate(islice(corpus_ds, max_pages)):
            img = row.get("image")
            if not isinstance(img, PILImage.Image):
                continue
            cid = str(row[corpus_id_col])
            page_id = f"{subset}_{i:04d}"
            with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f:
                tmp_path = f.name
                img.convert("RGB").save(tmp_path, format="JPEG")
            del img  # free PIL memory before upload
            page_file = await File.from_local(tmp_path)
            os.unlink(tmp_path)
            corpus_id_to_page_id[cid] = page_id
            page_ids.append(page_id)
            page_files.append(page_file)

        # Build query_id → relevant page_id from qrels (first match wins).
        # Only keep relevance judgements whose corpus page is in our slice.
        qrel_map: dict[str, str] = {}
        for row in qrels_ds:
            qid = str(row[qrel_qid_col])
            cid = str(row[qrel_cid_col])
            if cid in corpus_id_to_page_id and qid not in qrel_map:
                qrel_map[qid] = corpus_id_to_page_id[cid]

        # Collect queries that have at least one relevant page in our slice.
        queries: list[PageQuery] = []
        for row in queries_ds:
            qid = str(row[query_id_col])
            if qid not in qrel_map:
                continue
            queries.append(
                PageQuery(
                    query_id=qid,
                    text=str(row[query_text_col]),
                    relevant_page_id=qrel_map[qid],
                )
            )

    else:
        # ── Legacy format ─────────────────────────────────────────────────────
        # Single 'test' split with one row per (query, page) pair.
        ds = load_dataset(dataset_name, split="test", streaming=True)

        page_ids = []
        page_files = []
        queries = []
        seen_pages: dict[str, str] = {}  # image_filename → page_id

        for i, row in enumerate(islice(ds, max_pages)):
            img = row.get("image")
            if not isinstance(img, PILImage.Image):
                continue
            filename: str = row.get("image_filename") or f"page_{i}"
            query_text: str = row.get("query", "")
            if not query_text:
                continue

            # Each unique page is uploaded exactly once; multiple queries may
            # share the same page (same image_filename).
            if filename not in seen_pages:
                page_id = f"{subset}_{len(page_ids):04d}"
                with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f:
                    tmp_path = f.name
                    img.convert("RGB").save(tmp_path, format="JPEG")
                del img  # free PIL memory before upload
                page_file = await File.from_local(tmp_path)
                os.unlink(tmp_path)
                seen_pages[filename] = page_id
                page_ids.append(page_id)
                page_files.append(page_file)
            else:
                page_id = seen_pages[filename]

            queries.append(
                PageQuery(
                    query_id=f"q{i:04d}",
                    text=query_text,
                    relevant_page_id=page_id,
                )
            )

    print(f"Loaded {len(page_ids)} unique pages, {len(queries)} queries", flush=True)
    return PageDataset(page_ids=page_ids, page_files=page_files, queries=queries)

# ─────────────────────────────────────────────────────────────────────────────
# Tasks — indexing
# ─────────────────────────────────────────────────────────────────────────────

@colpali_indexer.task(cache="auto", retries=2)
async def index_colpali(page_ids: list[str], page_files: list[File]) -> File:
    """
    Encode every page with ColPali-v1.2 and save the multi-vector index.

    ColPali skips OCR entirely. It feeds the raw page image into PaliGemma
    (a vision-language model) and produces one embedding vector per image
    patch — roughly 1,024 patches per page, each of dimension 128.

    _colpali_model() is an lru_cache'd loader. On a cold container, it
    downloads and loads the model once. On a warm container (kept alive by
    ReusePolicy), it returns the already-loaded model instantly from cache —
    no repeated ~7 GB download.

    The index is stored as a .npz file in Flyte's blob store:
      embeddings — float32, shape (n_pages, n_patches, dim)
      page_ids   — matching page ID strings

    cache="auto" + retries=2: the result is stored permanently on success;
    transient failures (e.g. HuggingFace rate limits) are retried twice.
    """
    import torch

    model, processor, device = _colpali_model()

    loop = asyncio.get_running_loop()
    batches = list(_batches(page_files, 4))
    n_batches = len(batches)

    # Submit the first batch to the thread pool before entering the loop so
    # that downloads are already in flight when we first await them.
    prefetch = [loop.run_in_executor(None, _load_image_sync, f) for f in batches[0]]

    all_embeddings: list[np.ndarray] = []
    for batch_idx in range(n_batches):
        images = list(await asyncio.gather(*prefetch))

        # Submit next batch downloads immediately — OS threads run these in
        # parallel with the GPU forward pass below.
        if batch_idx + 1 < n_batches:
            prefetch = [loop.run_in_executor(None, _load_image_sync, f) for f in batches[batch_idx + 1]]

        inputs = processor.process_images(images)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            emb = model(**inputs)  # (batch, n_patches, dim)

        all_embeddings.append(emb.cpu().float().numpy())
        print(f"ColPali: indexed batch {batch_idx + 1}/{n_batches}", flush=True)

    embeddings = np.concatenate(all_embeddings, axis=0)  # (n_pages, n_patches, dim)
    out_path = os.path.join(tempfile.gettempdir(), "colpali_index.npz")
    np.savez(out_path, embeddings=embeddings, page_ids=np.array(page_ids))
    return await File.from_local(out_path)

@siglip_indexer.task(cache="auto", retries=2)
async def index_siglip(page_ids: list[str], page_files: list[File]) -> File:
    """
    Encode every page with SigLIP SO400M and save the single-vector index.

    SigLIP (2023) is Google's CLIP-style image-text model, trained with a
    pairwise sigmoid loss instead of CLIP's batch-wide softmax. The sigmoid
    loss needs no normalisation over all pairs in a batch, which makes large
    batches cheaper to train. It produces one global embedding per page.

    _siglip_model() caches the model across warm container reuses.

    The index is stored as a .npz file:
      embeddings — float32, shape (n_pages, dim), L2-normalised
      page_ids   — matching page ID strings
    """
    import torch

    model, processor, device = _siglip_model()

    loop = asyncio.get_running_loop()
    batches = list(_batches(page_files, 8))
    n_batches = len(batches)

    # Submit the first batch to the thread pool before entering the loop so
    # that downloads are already in flight when we first await them.
    prefetch = [loop.run_in_executor(None, _load_image_sync, f) for f in batches[0]]

    all_embeddings: list[np.ndarray] = []
    for batch_idx in range(n_batches):
        images = list(await asyncio.gather(*prefetch))

        # Submit next batch downloads immediately — OS threads run these in
        # parallel with the GPU forward pass below.
        if batch_idx + 1 < n_batches:
            prefetch = [loop.run_in_executor(None, _load_image_sync, f) for f in batches[batch_idx + 1]]

        inputs = processor(images=images, return_tensors="pt", padding=True).to(device)

        with torch.no_grad():
            outputs = model.vision_model(**inputs)
            emb = outputs.pooler_output  # (batch, dim)
            emb = emb / emb.norm(dim=-1, keepdim=True)  # L2 normalise

        all_embeddings.append(emb.cpu().float().numpy())
        print(f"SigLIP: indexed batch {batch_idx + 1}/{n_batches}", flush=True)

    embeddings = np.concatenate(all_embeddings, axis=0)  # (n_pages, dim)
    out_path = os.path.join(tempfile.gettempdir(), "siglip_index.npz")
    np.savez(out_path, embeddings=embeddings, page_ids=np.array(page_ids))
    return await File.from_local(out_path)

@ocr_engine.task(cache="auto")
async def extract_page_texts(page_files: list[File]) -> list[str]:
    """
    OCR every page with doctr on GPU to produce a text-only baseline.

    doctr bundles DBNet (detection) + CRNN/SAR (recognition) into a single
    callable predictor. Pages are processed in batches of ocr_batch_size:
    each batch is downloaded in parallel and then passed to the predictor.
    asyncio.to_thread keeps the event loop unblocked during both the
    downloads and GPU inference.

    Result structure: result.pages[i].blocks[j].lines[k].words[l].value

    Cached: the same corpus is OCR'd at most once across all experiments
    that use the OCR+BM25 backend.
    """
    import gc

    predictor = _ocr_model()

    # Process in batches: download each batch just-in-time so only
    # ocr_batch_size images are in memory at once instead of all 2 000.
    ocr_batch_size = 8
    total = len(page_files)
    texts: list[str] = []
    for start in range(0, total, ocr_batch_size):
        batch_files = page_files[start : start + ocr_batch_size]
        batch_images = list(
            await asyncio.gather(*[asyncio.to_thread(_load_image_sync, f) for f in batch_files])
        )
        batch_np = [np.array(img) for img in batch_images]
        del batch_images
        result = await asyncio.to_thread(predictor, batch_np)
        del batch_np
        for page_output in result.pages:
            texts.append(
                "\n".join(
                    " ".join(word.value for word in line.words)
                    for block in page_output.blocks
                    for line in block.lines
                )
            )
        del result
        gc.collect()
        print(f"OCR: processed {min(start + ocr_batch_size, total)}/{total} pages", flush=True)

    return texts

# ─────────────────────────────────────────────────────────────────────────────
# Tasks — search
# ─────────────────────────────────────────────────────────────────────────────

# {{docs-fragment search_colpali}}
@colpali_indexer.task
async def search_colpali(
    index_file: File,
    queries: list[PageQuery],
    top_k: int,
) -> list[RetrievalResult]:
    """
    Retrieve pages using ColPali MaxSim late interaction via DynamicBatcher.

    MaxSim score for page p given query q:
        score(q, p) = Σ_{t ∈ query tokens} max_{j ∈ page patches} (q_t · p_j)

    Each query is submitted to the process-level DynamicBatcher, which
    aggregates queries from all concurrent search_colpali invocations on the
    same warm container (concurrency=8) into a single GPU batch. This keeps
    the GPU saturated rather than running one small batch per caller.

    The batcher's process_fn runs GPU work in asyncio.to_thread, so the
    aggregation loop stays live while the GPU encodes and scores.
    """
    batcher = await _get_colpali_search_batcher(index_file)
    futures = await batcher.submit_batch(queries)
    all_ranked: list[list[str]] = list(await asyncio.gather(*futures))

    return [
        RetrievalResult(query_id=q.query_id, ranked_page_ids=ranked[:top_k])
        for q, ranked in zip(queries, all_ranked)
    ]
# {{/docs-fragment search_colpali}}

@siglip_indexer.task
async def search_siglip(
    index_file: File,
    queries: list[PageQuery],
    top_k: int,
) -> list[RetrievalResult]:
    """
    Retrieve pages using SigLIP cosine similarity via DynamicBatcher.

    Each query is submitted to the process-level DynamicBatcher, which
    aggregates queries from all concurrent search_siglip invocations on the
    same warm container (concurrency=3) into a single GPU batch.

    SigLIP's single-vector embeddings make full vectorisation safe —
    the scores matrix (n_pages x n_queries) is small enough to materialise
    in one GPU call regardless of batch size.
    """
    batcher = await _get_siglip_search_batcher(index_file)
    futures = await batcher.submit_batch(queries)
    all_ranked: list[list[str]] = list(await asyncio.gather(*futures))

    return [
        RetrievalResult(query_id=q.query_id, ranked_page_ids=ranked[:top_k])
        for q, ranked in zip(queries, all_ranked)
    ]

@driver.task
async def search_bm25(
    page_texts: list[str],
    page_ids: list[str],
    queries: list[PageQuery],
    top_k: int,
) -> list[RetrievalResult]:
    """
    Retrieve pages using BM25 over OCR'd text.

    The standard keyword-based baseline. This search step needs no GPU (OCR
    runs upstream in extract_page_texts). It is strong on text-dense pages and
    weak on visual content that doctr's OCR cannot read.
    """
    tokenized = [text.lower().split() for text in page_texts]
    bm25 = BM25Okapi(tokenized)

    results: list[RetrievalResult] = []
    for q in queries:
        scores = bm25.get_scores(q.text.lower().split())
        ranked = sorted(range(len(page_ids)), key=lambda i: -scores[i])[:top_k]
        results.append(
            RetrievalResult(
                query_id=q.query_id,
                ranked_page_ids=[page_ids[i] for i in ranked],
            )
        )
    return results

# ─────────────────────────────────────────────────────────────────────────────
# Tasks — evaluation
# ─────────────────────────────────────────────────────────────────────────────

@driver.task
async def evaluate(
    results: list[RetrievalResult],
    ground_truth: list[PageQuery],
    k: int,
) -> Metrics:
    """
    Compute Recall@K, NDCG@K, and MRR for a single retrieval model.

    Recall@K  — was the correct page in the top-K results?
    NDCG@K    — normalised discounted cumulative gain; rewards earlier hits.
    MRR       — mean reciprocal rank of the first correct result.

    All three are averaged over all queries; higher is better. Search results
    are already truncated to top_k, so MRR here is effectively MRR@K.
    """
    gt_map = {q.query_id: q.relevant_page_id for q in ground_truth}
    recall_vals, ndcg_vals, mrr_vals = [], [], []

    for r in results:
        relevant = gt_map.get(r.query_id, "")
        top = r.ranked_page_ids[:k]

        recall_vals.append(1.0 if relevant in top else 0.0)

        rels = [1 if pid == relevant else 0 for pid in top]
        idcg = _dcg([1])  # ideal: correct page at rank 1
        ndcg_vals.append(_dcg(rels) / idcg if idcg > 0 else 0.0)

        rr = 0.0
        for rank, pid in enumerate(r.ranked_page_ids, start=1):
            if pid == relevant:
                rr = 1.0 / rank
                break
        mrr_vals.append(rr)

    return Metrics(
        recall_at_k=float(np.mean(recall_vals)),
        ndcg_at_k=float(np.mean(ndcg_vals)),
        mrr=float(np.mean(mrr_vals)),
        k=k,
    )

# ─────────────────────────────────────────────────────────────────────────────
# Tasks — report
# ─────────────────────────────────────────────────────────────────────────────

@driver.task(report=True)
async def generate_report(report: ComparisonReport) -> None:
    """
    Emit an interactive HTML report visible in the Flyte UI.

    report=True marks this task as a reporting task. Flyte renders the HTML
    passed to flyte.report.replace.aio() directly in the execution detail
    page, so no separate dashboard or export step is required.

    The report contains:
      - Summary cards: experiment count, the best experiment by Recall@K, and
        that experiment's Recall@K, NDCG@K, and MRR.
      - Grouped bar chart: Recall@K, NDCG@K, and MRR side by side per experiment.
      - Results table ranked by Recall@K, with all three metrics.
    """
    sorted_results = sorted(report.results, key=lambda r: -r.metrics.recall_at_k)
    best = sorted_results[0]

    labels = [r.config.name for r in sorted_results]
    recall_vals = [r.metrics.recall_at_k for r in sorted_results]
    ndcg_vals = [r.metrics.ndcg_at_k for r in sorted_results]
    mrr_vals = [r.metrics.mrr for r in sorted_results]

    table_rows = "".join(
        f"""
        <tr>
          <td>{r.config.name}</td>
          <td>{r.config.model.value}</td>
          <td>{r.metrics.recall_at_k:.3f}</td>
          <td>{r.metrics.ndcg_at_k:.3f}</td>
          <td>{r.metrics.mrr:.3f}</td>
          <td>{r.metrics.k}</td>
        </tr>"""
        for r in sorted_results
    )

    html = f"""<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <title>Visual Document Retrieval — Results</title>
  <script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
  <style>
    * {{ box-sizing: border-box; margin: 0; padding: 0; }}
    body {{
      font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;
      background: #f0f2f5; color: #222; padding: 24px;
    }}
    h1 {{ font-size: 1.6em; margin-bottom: 4px; }}
    .subtitle {{ color: #666; margin-bottom: 24px; font-size: 0.95em; }}
    .cards {{
      display: flex; gap: 16px; flex-wrap: wrap; margin-bottom: 28px;
    }}
    .card {{
      background: #fff; border-radius: 10px; padding: 18px 24px;
      box-shadow: 0 1px 4px rgba(0,0,0,.08); min-width: 160px;
    }}
    .card-value {{ font-size: 1.9em; font-weight: 700; color: #4f46e5; }}
    .card-label {{ font-size: 0.8em; color: #888; text-transform: uppercase;
                   letter-spacing: .04em; margin-top: 2px; }}
    .chart-box {{
      background: #fff; border-radius: 10px; padding: 24px;
      box-shadow: 0 1px 4px rgba(0,0,0,.08); margin-bottom: 28px;
    }}
    .chart-box h2 {{ font-size: 1em; margin-bottom: 16px; color: #444; }}
    table {{ width: 100%; border-collapse: collapse; font-size: 0.9em; }}
    th {{
      background: #4f46e5; color: #fff; padding: 10px 14px;
      text-align: left; font-weight: 600;
    }}
    td {{ padding: 9px 14px; border-bottom: 1px solid #eee; }}
    tr:hover td {{ background: #f8f8ff; }}
    tr:first-child td {{ font-weight: 600; }}
  </style>
</head>
<body>
  <h1>Visual Document Retrieval — Experiment Comparison</h1>
  <p class="subtitle">ViDoRe benchmark &middot; {len(report.results)} experiment(s)</p>

  <div class="cards">
    <div class="card">
      <div class="card-value">{len(report.results)}</div>
      <div class="card-label">Experiments</div>
    </div>
    <div class="card">
      <div class="card-value">{best.config.name}</div>
      <div class="card-label">Best by Recall@K</div>
    </div>
    <div class="card">
      <div class="card-value">{best.metrics.recall_at_k:.3f}</div>
      <div class="card-label">Best Recall@{best.metrics.k}</div>
    </div>
    <div class="card">
      <div class="card-value">{best.metrics.ndcg_at_k:.3f}</div>
      <div class="card-label">Best NDCG@{best.metrics.k}</div>
    </div>
    <div class="card">
      <div class="card-value">{best.metrics.mrr:.3f}</div>
      <div class="card-label">Best MRR</div>
    </div>
  </div>

  <div class="chart-box">
    <h2>Metrics by Experiment</h2>
    <canvas id="metricsChart" height="100"></canvas>
  </div>

  <div class="chart-box">
    <h2>Ranked Results</h2>
    <table>
      <thead>
        <tr>
          <th>Experiment</th><th>Model</th>
          <th>Recall@K</th><th>NDCG@K</th><th>MRR</th><th>K</th>
        </tr>
      </thead>
      <tbody>{table_rows}</tbody>
    </table>
  </div>

  <script>
    new Chart(document.getElementById('metricsChart'), {{
      type: 'bar',
      data: {{
        labels: {json.dumps(labels)},
        datasets: [
          {{
            label: 'Recall@K',
            data: {json.dumps(recall_vals)},
            backgroundColor: 'rgba(79,70,229,0.75)',
            borderRadius: 4
          }},
          {{
            label: 'NDCG@K',
            data: {json.dumps(ndcg_vals)},
            backgroundColor: 'rgba(16,185,129,0.75)',
            borderRadius: 4
          }},
          {{
            label: 'MRR',
            data: {json.dumps(mrr_vals)},
            backgroundColor: 'rgba(245,158,11,0.75)',
            borderRadius: 4
          }}
        ]
      }},
      options: {{
        responsive: true,
        plugins: {{ legend: {{ position: 'top' }} }},
        scales: {{
          y: {{ beginAtZero: true, max: 1.0,
               title: {{ display: true, text: 'Score' }} }}
        }}
      }}
    }});
  </script>
</body>
</html>"""

    await flyte.report.replace.aio(html)
    await flyte.report.flush.aio()

# ─────────────────────────────────────────────────────────────────────────────
# Experiment orchestration
# ─────────────────────────────────────────────────────────────────────────────

# {{docs-fragment run_experiment}}
@driver.task
async def run_experiment(config: ExperimentConfig, dataset: PageDataset) -> ExperimentResult:
    """
    End-to-end retrieval pipeline for a single ExperimentConfig.

    Flyte v2's dynamic execution means this driver task can call GPU tasks
    (index_colpali, search_colpali) based on the runtime value of config.model
    — no static DAG wiring required. The if/elif is plain Python; Flyte
    schedules the selected sub-tasks on the appropriate environment.

    Caching: two experiments that share the same model and corpus (e.g. ColPali
    at top_k=5 and top_k=10) will hit the same cached index. GPU work is paid
    at most once per (model, corpus) pair across all experiments.

    Search queries are sharded into chunks of SEARCH_SHARD_SIZE and dispatched
    as concurrent task invocations. All shards land on the single warm container
    (replicas=1) and feed the same DynamicBatcher simultaneously, keeping the
    GPU saturated throughout search rather than processing one large sequential
    batch from a single caller.

    flyte.group wraps each experiment in a named span in the Flyte UI, making
    it easy to compare latencies and drill into individual runs.
    """
    SEARCH_SHARD_SIZE = 256

    with flyte.group(config.name):
        if config.model == RetrievalModel.COLPALI:
            index_file = await index_colpali(dataset.page_ids, dataset.page_files)
            shards = list(_batches(dataset.queries, SEARCH_SHARD_SIZE))
            shard_results = await asyncio.gather(
                *[search_colpali(index_file, shard, config.top_k) for shard in shards]
            )
            results = [r for shard in shard_results for r in shard]

        elif config.model == RetrievalModel.SIGLIP:
            index_file = await index_siglip(dataset.page_ids, dataset.page_files)
            shards = list(_batches(dataset.queries, SEARCH_SHARD_SIZE))
            shard_results = await asyncio.gather(
                *[search_siglip(index_file, shard, config.top_k) for shard in shards]
            )
            results = [r for shard in shard_results for r in shard]

        else:  # RetrievalModel.OCR_BM25
            page_texts = await extract_page_texts(dataset.page_files)
            results = await search_bm25(page_texts, dataset.page_ids, dataset.queries, config.top_k)

        metrics = await evaluate(results, dataset.queries, config.top_k)

    return ExperimentResult(config=config, metrics=metrics)
# {{/docs-fragment run_experiment}}

# {{docs-fragment compare_experiments}}
@driver.task
async def compare_experiments(
    configs: list[ExperimentConfig],
    subset: str = "docvqa",
    max_pages: int = 200,
) -> ComparisonReport:
    """
    Fan out over all experiment configs and return a ranked comparison table.

    The dataset is loaded once and shared across all experiments. Each config
    runs as a concurrent Flyte task via asyncio.gather. Experiments that share
    a model reuse the cached index — you only pay GPU time for new work.

    On completion, generate_report emits an interactive Chart.js HTML report
    visible directly in the Flyte execution detail page.

    The function defaults (subset="docvqa", max_pages=200) give a quick run.
    The entry point below uses vidore_v3_finance_en (~2 942 corpus pages,
    1 854 queries) with max_pages=2000 to exercise the GPU pipeline at scale.
    """
    dataset = await load_vidore_pages(subset=subset, max_pages=max_pages)

    # All experiments launch concurrently. Shared cached outputs (same model,
    # same corpus) are served from cache rather than recomputed.
    experiment_coros = [run_experiment(config=cfg, dataset=dataset) for cfg in configs]
    results: list[ExperimentResult] = list(await asyncio.gather(*experiment_coros))

    report = ComparisonReport(results=results)
    print(report.summary())
    best = report.best_by("recall_at_k")
    print(f"\nBest by Recall@{best.metrics.k}: {best.config.name}")

    # Emit the interactive HTML report in the Flyte UI.
    await generate_report(report)

    return report
# {{/docs-fragment compare_experiments}}

# ─────────────────────────────────────────────────────────────────────────────
# Entry point
# ─────────────────────────────────────────────────────────────────────────────

if __name__ == "__main__":
    flyte.init_from_config()

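    # Define the experiment grid. Each ExperimentConfig is one point in the
    # design space. Varying top_k or adding another run of an existing model is
    # one line here, with no task code changes required; a new retrieval model
    # also needs its own tasks and a branch in run_experiment.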
    #
    # ColPali appears twice with different top_k values. The cache ensures
    # index_colpali runs only once and both experiments share that result.
    # {{docs-fragment grid}}
    configs = [
        ExperimentConfig(name="colpali-top5", model=RetrievalModel.COLPALI, top_k=5),
        ExperimentConfig(name="colpali-top10", model=RetrievalModel.COLPALI, top_k=10),
        ExperimentConfig(name="siglip-top5", model=RetrievalModel.SIGLIP, top_k=5),
        ExperimentConfig(name="ocr-bm25-top5", model=RetrievalModel.OCR_BM25, top_k=5),
    ]
    # {{/docs-fragment grid}}

    run = flyte.with_runcontext().run(
        compare_experiments,
        configs=configs,
        subset="vidore_v3_finance_en",
        max_pages=2000,
    )
    print(f"Run URL: {run.url}")
```
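The MaxSim formula in the `search_colpali` docstring works as follows. For each query token, it finds the best dot product against any page patch, then sums those maxima over the tokens. Below is a minimal pure-Python sketch on toy two-dimensional vectors. The helpers `maxsim_score` and `rank_pages` are hypothetical; the tutorial instead scores batches of queries on the GPU.

```python
def dot(a: list[float], b: list[float]) -> float:
    """Plain dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def maxsim_score(query_tokens: list[list[float]], page_patches: list[list[float]]) -> float:
    """Late interaction: each query token takes its best-matching patch; sum over tokens."""
    return sum(max(dot(q, p) for p in page_patches) for q in query_tokens)


def rank_pages(query_tokens: list[list[float]], pages: dict[str, list[list[float]]]) -> list[str]:
    """Return page IDs sorted by descending MaxSim score."""
    return sorted(pages, key=lambda pid: maxsim_score(query_tokens, pages[pid]), reverse=True)


# Toy data: two query-token vectors and two pages with two patch vectors each.
query = [[1.0, 0.0], [0.0, 1.0]]
pages = {
    "page_a": [[0.9, 0.1], [0.2, 0.2]],  # strong match for token 1 only
    "page_b": [[0.8, 0.0], [0.0, 0.7]],  # good matches for both tokens
}
ranking = rank_pages(query, pages)
```

`page_b` ranks first (score 0.8 + 0.7) even though `page_a` holds the single best patch for the first token (score 0.9 + 0.2). Each query token is matched independently, so a page that covers every part of the query wins.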

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/multimodal-retrieval-evaluation/retrieval_eval.py*
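Each query has exactly one relevant page, so the per-query metrics in `evaluate` depend only on that page's 1-based rank r:

- Recall@K is 1 when r ≤ K.
- NDCG@K is 1 / log2(r + 1), because the ideal DCG is 1.
- The reciprocal rank is 1 / r.

All three are 0 on a miss. The minimal sketch below reproduces that per-query logic on hypothetical rankings; `per_query_metrics` is an illustrative name, not a tutorial function.

```python
import math


def dcg(relevances: list[int]) -> float:
    """Same discount as the tutorial's _dcg: 0-based position i is divided by log2(i + 2)."""
    return sum(rel / math.log2(i + 2) for i, rel in enumerate(relevances))


def per_query_metrics(ranked: list[str], relevant: str, k: int) -> tuple[float, float, float]:
    """Recall@K, NDCG@K, and reciprocal rank for one query with a single relevant page."""
    top = ranked[:k]
    recall = 1.0 if relevant in top else 0.0
    ndcg = dcg([1 if pid == relevant else 0 for pid in top]) / dcg([1])  # ideal DCG = 1
    rr = next((1.0 / r for r, pid in enumerate(ranked, start=1) if pid == relevant), 0.0)
    return recall, ndcg, rr


# Hypothetical rankings: a hit at rank 1, a hit at rank 2, and a miss.
runs = [
    (["p1", "p2", "p3"], "p1"),
    (["p2", "p3", "p1"], "p3"),
    (["p2", "p4", "p5"], "p1"),
]
scores = [per_query_metrics(ranked, rel, k=3) for ranked, rel in runs]
recall_at_3 = sum(s[0] for s in scores) / len(scores)
ndcg_at_3 = sum(s[1] for s in scores) / len(scores)
mrr = sum(s[2] for s in scores) / len(scores)
```

Recall@3 here is 2/3 and MRR is (1 + 0.5 + 0) / 3 = 0.5. NDCG@3 rewards the rank-1 hit fully and the rank-2 hit by 1 / log2(3).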

The corpus, queries, retrieval results, and metrics are likewise typed. Page images are stored as `flyte.io.File` handles in blob storage, so downstream tasks read them from the blob store instead of re-fetching them from Hugging Face over HTTP.
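The record types that flow between the tasks can be sketched from the fields the tasks access. This is a minimal sketch using stdlib dataclasses. The tutorial's own definitions, which import pydantic and `enum`, may differ and carry more fields. Several details here are hypothetical: the `RetrievalModel` string values, the `best_by` body, and the use of plain strings in place of `flyte.io.File` handles.

```python
import enum
from dataclasses import dataclass


class RetrievalModel(enum.Enum):
    # Hypothetical enum values; only the member names appear in the tutorial.
    COLPALI = "colpali"
    SIGLIP = "siglip"
    OCR_BM25 = "ocr_bm25"


@dataclass
class PageQuery:
    query_id: str
    text: str
    relevant_page_id: str  # the single page judged relevant


@dataclass
class PageDataset:
    page_ids: list[str]
    page_files: list[str]  # stand-in for list[flyte.io.File]
    queries: list[PageQuery]


@dataclass
class RetrievalResult:
    query_id: str
    ranked_page_ids: list[str]  # best first, truncated to top_k


@dataclass
class Metrics:
    recall_at_k: float
    ndcg_at_k: float
    mrr: float
    k: int


@dataclass
class ExperimentConfig:
    name: str
    model: RetrievalModel
    top_k: int


@dataclass
class ExperimentResult:
    config: ExperimentConfig
    metrics: Metrics


@dataclass
class ComparisonReport:
    results: list[ExperimentResult]

    def best_by(self, metric: str) -> ExperimentResult:
        # Hypothetical implementation: highest value of the named Metrics field.
        return max(self.results, key=lambda r: getattr(r.metrics, metric))


# Toy values for illustration only, not benchmark results.
report = ComparisonReport(results=[
    ExperimentResult(ExperimentConfig("a-top5", RetrievalModel.COLPALI, 5), Metrics(0.9, 0.8, 0.7, 5)),
    ExperimentResult(ExperimentConfig("b-top5", RetrievalModel.OCR_BM25, 5), Metrics(0.5, 0.4, 0.3, 5)),
])
best = report.best_by("recall_at_k")
```

Typed records like these let each task declare exactly what it consumes and produces, which Flyte can then serialize and cache between tasks.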

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "colpali-engine>=0.3.1",
#     "transformers>=4.41",
#     "sentencepiece>=0.2",
#     "torch>=2.0",
#     "pillow>=10",
#     "datasets>=2.18",
#     "rank-bm25>=0.2",
#     "numpy>=1.26",
#     "python-doctr[torch]>=0.8",
#     "pydantic>=2.0",
#     "flyte>=2.0.0",
# ]
# ///
"""
Multimodal Retrieval Evaluation Pipeline

This tutorial is an experiment framework for benchmarking visual document
retrieval approaches on the ViDoRe benchmark. Each experiment is defined by
an ExperimentConfig; the pipeline fans them out as concurrent Flyte tasks and
returns a ranked comparison table with an interactive HTML report.

The corpus is a set of PDF page images; queries are plain-text questions. Each
retrieval method must find the page that answers each question — no text is
provided to the model, only the raw image.

  ColPali-v1.2  — patch-level multi-vector embeddings from a VLM (PaliGemma).
                  No OCR. The model produces one vector per image patch
                  (~1024 per page). MaxSim late-interaction scoring finds the
                  best matching patch for each query token.

  SigLIP-SO400M — single global embedding per page from Google's 2023
                  sigmoid-loss variant of CLIP. One matrix multiply per
                  query; fast and effective, but a single vector cannot
                  localise fine-grained regions.

  OCR + BM25    — text-only baseline. doctr (GPU OCR) extracts text in
                  batches, BM25 matches keywords. Strong on text-dense pages;
                  fails on charts, tables, and figures where content is visual.

"""

import asyncio
import enum
import json
import math
import os
import tempfile
from functools import lru_cache
from io import BytesIO
from itertools import islice

import numpy as np
from PIL import Image as PILImage
from pydantic import BaseModel
from rank_bm25 import BM25Okapi

from extras import DynamicBatcher

import flyte
import flyte.report
from flyte.io import File

# ─────────────────────────────────────────────────────────────────────────────
# Environments
# ─────────────────────────────────────────────────────────────────────────────

# One Docker image for all tasks. The PEP 723 header defines Python deps.
# ca-certificates is required for HTTPS calls to HuggingFace and blob stores.
# {{docs-fragment image}}
image = (
    flyte.Image.from_uv_script(__file__, name="vidore-eval-v2")
    .with_apt_packages("ca-certificates", "libxcb1", "libgl1", "libglib2.0-0")
    # unionai-reuse installs the unionai-actor-bridge binary required by ReusePolicy.
    # Without it every reusable container exits with StartError (exit code 128).
    .with_pip_packages("unionai-reuse>=0.1.11")
)
# {{/docs-fragment image}}

# GPU environment for ColPali image encoding and search.
#
# ReusePolicy keeps a warm GPU container (replicas=1 below) alive between task calls.
# Without it, every task invocation cold-starts a new container and downloads
# ColPali-v1.2 (~7 GB) from scratch. With it, the container — and the model
# weights already loaded into VRAM — is reused for the next task dispatch.
#
#   replicas=1      single warm container — all concurrent shard calls land
#                   here so they share one DynamicBatcher process
#   concurrency=8   up to 8 query-shard tasks run simultaneously on the
#                   container, all feeding the same DynamicBatcher queue
#   idle_ttl=120    keep alive 2 min after the last task finishes
#   scaledown_ttl=60 scale to zero after 1 min of complete inactivity
# {{docs-fragment envs}}
colpali_indexer = flyte.TaskEnvironment(
    name="vidore-colpali-indexer",
    image=image,
    resources=flyte.Resources(cpu=4, memory="16Gi", gpu="A10G:1"),
    reusable=flyte.ReusePolicy(
        replicas=1,
        concurrency=8,
        idle_ttl=120,
        scaledown_ttl=60,
    ),
)

# GPU environment for SigLIP image encoding and search.
#
# Separate from the ColPali environment so each model's warm containers
# are managed independently — ColPali and SigLIP experiments can scale
# without contending for the same pool of reusable containers.
siglip_indexer = flyte.TaskEnvironment(
    name="vidore-siglip-indexer",
    image=image,
    resources=flyte.Resources(cpu=4, memory="8Gi", gpu=1),
    reusable=flyte.ReusePolicy(
        replicas=1,
        concurrency=8,
        idle_ttl=120,
        scaledown_ttl=60,
    ),
)

# GPU environment for doctr OCR. doctr runs DBNet (detection) + CRNN (recognition)
# in batches on GPU — much faster than CPU Tesseract.
# No ReusePolicy needed: the result is cached, so this task runs at most once per corpus.
ocr_engine = flyte.TaskEnvironment(
    name="vidore-ocr-engine",
    image=image,
    resources=flyte.Resources(cpu=4, memory="20Gi", gpu=1),
)

# Driver: orchestration, BM25 search, evaluation, and reporting.
# depends_on ensures the shared Docker image is built before all environments
# try to schedule tasks.
driver = flyte.TaskEnvironment(
    name="vidore-driver",
    image=image,
    resources=flyte.Resources(cpu=2, memory="12Gi"),
    depends_on=[colpali_indexer, siglip_indexer, ocr_engine],
)
# {{/docs-fragment envs}}

# ─────────────────────────────────────────────────────────────────────────────
# Configuration types
# ─────────────────────────────────────────────────────────────────────────────

# {{docs-fragment config_types}}
class RetrievalModel(str, enum.Enum):
    """Retrieval backend to evaluate."""

    COLPALI = "colpali-v1.2"  # multi-vector patch embeddings, MaxSim
    SIGLIP = "siglip-so400m"  # single-vector global embedding, cosine sim
    OCR_BM25 = "ocr+bm25"  # text extracted by doctr OCR, ranked by BM25

class ExperimentConfig(BaseModel):
    """
    All knobs for one retrieval experiment. Passed as a typed Flyte input.

    Because ExperimentConfig is a Pydantic model, Flyte serialises it
    alongside every task output — so you can always reconstruct which
    config produced which metric without maintaining a separate log.
    """

    name: str  # human-readable label shown in the comparison table
    model: RetrievalModel
    top_k: int = 5  # number of pages to retrieve per query
# {{/docs-fragment config_types}}

# ─────────────────────────────────────────────────────────────────────────────
# Data types
# ─────────────────────────────────────────────────────────────────────────────

# {{docs-fragment data_types}}
class PageQuery(BaseModel):
    """One retrieval query with its ground-truth page."""

    query_id: str
    text: str  # e.g. "What was revenue growth in Q3?"
    relevant_page_id: str  # one correct page per query

class PageDataset(BaseModel):
    """
    A corpus of document page images paired with text queries.

    page_ids:   unique page identifiers (derived from ViDoRe image filenames).
    page_files: the same pages stored in Flyte's blob store as JPEG File
                handles. Tasks read images directly from here; no live HTTP.
    queries:    text questions with ground-truth page IDs for evaluation.
    """

    page_ids: list[str]
    page_files: list[File]
    queries: list[PageQuery]

    model_config = {"arbitrary_types_allowed": True}

class RetrievalResult(BaseModel):
    query_id: str
    ranked_page_ids: list[str]  # ordered best → worst

class Metrics(BaseModel):
    recall_at_k: float
    ndcg_at_k: float
    mrr: float
    k: int

class ExperimentResult(BaseModel):
    config: ExperimentConfig
    metrics: Metrics
# {{/docs-fragment data_types}}

class ComparisonReport(BaseModel):
    results: list[ExperimentResult]

    def best_by(self, metric: str = "recall_at_k") -> ExperimentResult:
        return max(self.results, key=lambda r: getattr(r.metrics, metric))

    def summary(self) -> str:
        header = f"{'Experiment':<30} {'Model':<18} {'Recall@K':>10} {'NDCG@K':>8} {'MRR':>7}"
        sep = "─" * len(header)
        rows = [header, sep]
        for r in sorted(self.results, key=lambda x: -x.metrics.recall_at_k):
            rows.append(
                f"{r.config.name:<30} "
                f"{r.config.model.value:<18} "
                f"{r.metrics.recall_at_k:>10.3f} "
                f"{r.metrics.ndcg_at_k:>8.3f} "
                f"{r.metrics.mrr:>7.3f}"
            )
        return "\n".join(rows)

# ─────────────────────────────────────────────────────────────────────────────
# Cached model loaders
# ─────────────────────────────────────────────────────────────────────────────
# These functions are at module level so they are shared across all tasks that
# run on the same warm container (via ReusePolicy). lru_cache(maxsize=1) means
# the model is loaded from disk/HuggingFace exactly once per container process
# and kept in GPU memory for every subsequent task dispatch to that container.

@lru_cache(maxsize=1)
def _colpali_model():
    """Load ColPali-v1.2 into GPU memory and cache the result.

    device_map= is the correct loading pattern for ColPali's PaliGemma
    backbone; it handles weight placement via accelerate. torch.compile is
    skipped — ColPali is GPU-compute-bound and the DynamicBatcher's cross-
    invocation batching is the primary GPU utilisation mechanism.
    """
    import torch
    from colpali_engine.models import ColPali, ColPaliProcessor

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = ColPali.from_pretrained(
        "vidore/colpali-v1.2",
        torch_dtype=torch.bfloat16,
        device_map=device,
    )
    processor = ColPaliProcessor.from_pretrained("vidore/colpali-v1.2")
    return model, processor, device

@lru_cache(maxsize=1)
def _siglip_model():
    """Load SigLIP SO400M into GPU memory, compile it, and cache the result.

    torch.compile (mode="reduce-overhead") fuses the vision and text encoder
    transformer layers into optimised CUDA kernels. Because the compiled model
    is cached, the compilation overhead is paid once per warm container
    lifetime. (ColPali, by contrast, is not compiled.)
    """
    import torch
    from transformers import AutoModel, AutoProcessor

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModel.from_pretrained("google/siglip-so400m-patch14-224").to(device)
    if device == "cuda":
        model = torch.compile(model, mode="reduce-overhead")
    processor = AutoProcessor.from_pretrained("google/siglip-so400m-patch14-224")
    return model, processor, device

@lru_cache(maxsize=1)
def _ocr_model():
    """Load the doctr OCR predictor onto GPU and cache it.

    doctr's ocr_predictor bundles a detection model (DBNet) and a
    recognition model (CRNN/SAR) into a single callable. pretrained=True
    downloads both model weights from doctr's model zoo on first use.
    """
    import torch
    from doctr.models import ocr_predictor

    predictor = ocr_predictor(pretrained=True)
    if torch.cuda.is_available():
        predictor = predictor.cuda()
    return predictor

# ─────────────────────────────────────────────────────────────────────────────
# Search batcher singletons
# ─────────────────────────────────────────────────────────────────────────────
# One DynamicBatcher per model, shared across all concurrent search task
# invocations on the same warm container (concurrency=8). Queries from every
# concurrent caller are aggregated into a single GPU batch, maximizing
# throughput compared to each invocation running its own forward pass.
#
# Initialised lazily on the first search call via double-checked locking and
# lives for the container's lifetime. The process_fn runs GPU work via
# asyncio.to_thread so the aggregation loop can continue collecting queries
# from other callers while the GPU processes the current batch.
#
# File is not hashable so alru_cache cannot be used here; module-level state
# with asyncio.Lock is the correct pattern.
#
# Assumption: index_colpali/index_siglip use cache="auto", so the same corpus
# always produces the same index File across all callers on this container. If
# the index file ever changed between calls, the batcher would silently continue
# using the corpus embeddings loaded from the first call.

_colpali_batcher: DynamicBatcher | None = None
_colpali_batcher_lock = asyncio.Lock()
_siglip_batcher: DynamicBatcher | None = None
_siglip_batcher_lock = asyncio.Lock()

async def _get_colpali_search_batcher(index_file: File) -> DynamicBatcher:
    """Return the process-level ColPali search batcher, creating it on first call."""
    global _colpali_batcher
    if _colpali_batcher is not None:
        return _colpali_batcher
    async with _colpali_batcher_lock:
        if _colpali_batcher is not None:
            return _colpali_batcher

        import torch

        data = await _load_npz(index_file)
        corpus_emb = torch.from_numpy(data["embeddings"])  # (n_pages, n_patches, dim)
        index_page_ids: list[str] = list(data["page_ids"])
        model, processor, device = _colpali_model()
        corpus_emb = corpus_emb.to(device, dtype=torch.float32)

        async def colpali_process_fn(batch: list[PageQuery]) -> list[list[str]]:
            def _gpu_work() -> list[list[str]]:
                query_inputs = processor.process_queries([q.text for q in batch])
                query_inputs = {k: v.to(device) for k, v in query_inputs.items()}
                with torch.no_grad():
                    query_embs = model(**query_inputs).float()  # (B, T, D)
                    query_chunk = 8  # bounds the (chunk, T, n_pages, n_patches) einsum intermediate
                    n_pages = corpus_emb.shape[0]
                    all_scores = torch.empty(len(batch), n_pages, device=device)
                    for start in range(0, len(batch), query_chunk):
                        chunk = query_embs[start : start + query_chunk]
                        all_scores[start : start + query_chunk] = (
                            torch.einsum("ctd,pjd->ctpj", chunk, corpus_emb)
                            .max(dim=3).values
                            .sum(dim=1)
                        )
                    sorted_indices = all_scores.argsort(dim=1, descending=True).cpu().tolist()
                return [[index_page_ids[j] for j in ranked] for ranked in sorted_indices]

            # Run GPU work in a thread so the event loop — and the batcher's
            # aggregation loop — remain unblocked while the GPU is busy.
            return await asyncio.to_thread(_gpu_work)

        batcher: DynamicBatcher[PageQuery, list[str]] = DynamicBatcher(
            process_fn=colpali_process_fn,
            target_batch_cost=128,
            max_batch_size=128,
            batch_timeout_s=0.05,
            default_cost=1,
            prefetch_batches=2,
        )
        await batcher.start()
        _colpali_batcher = batcher
    return _colpali_batcher

async def _get_siglip_search_batcher(index_file: File) -> DynamicBatcher:
    """Return the process-level SigLIP search batcher, creating it on first call."""
    global _siglip_batcher
    if _siglip_batcher is not None:
        return _siglip_batcher
    async with _siglip_batcher_lock:
        if _siglip_batcher is not None:
            return _siglip_batcher

        import torch

        data = await _load_npz(index_file)
        corpus_emb = torch.from_numpy(data["embeddings"])  # (n_pages, dim), L2-normalised
        index_page_ids: list[str] = list(data["page_ids"])
        model, processor, device = _siglip_model()
        corpus_emb = corpus_emb.to(device)

        async def siglip_process_fn(batch: list[PageQuery]) -> list[list[str]]:
            def _gpu_work() -> list[list[str]]:
                text_inputs = processor(
                    text=[q.text for q in batch],
                    return_tensors="pt",
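                    # SigLIP was trained with text padded to max length, and its
                    # pooled output reads the last position; padding=True would make
                    # each query's embedding depend on the longest query in its batch.
                    padding="max_length",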
                    truncation=True,
                ).to(device)
                with torch.no_grad():
                    text_out = model.text_model(**text_inputs)
                    query_embs = text_out.pooler_output  # (B, dim)
                    query_embs = query_embs / query_embs.norm(dim=-1, keepdim=True)
                    scores_matrix = corpus_emb @ query_embs.T  # (n_pages, B)
                    sorted_indices = scores_matrix.argsort(dim=0, descending=True).T.cpu().tolist()
                return [[index_page_ids[j] for j in ranked] for ranked in sorted_indices]

            return await asyncio.to_thread(_gpu_work)

        batcher = DynamicBatcher(
            process_fn=siglip_process_fn,
            target_batch_cost=128,
            max_batch_size=128,
            batch_timeout_s=0.05,
            default_cost=1,
            prefetch_batches=2,
        )
        await batcher.start()
        _siglip_batcher = batcher
    return _siglip_batcher
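
# Illustrative sketch (hypothetical, not used by the pipeline): the core idea
# behind a dynamic batcher. This is NOT the extras.DynamicBatcher API; it only
# shows how concurrent callers can share one batched call. Each caller enqueues
# (item, future). A single background loop waits for the first item, keeps
# collecting until max_batch_size items arrive or batch_timeout_s elapses, runs
# process_fn once on the whole batch, then resolves each caller's future.
import asyncio


class _ExampleMicroBatcher:
    def __init__(self, process_fn, max_batch_size: int = 128, batch_timeout_s: float = 0.05):
        # process_fn: async, list[item] -> list[result] of the same length and order
        self._process_fn = process_fn
        self._max_batch_size = max_batch_size
        self._batch_timeout_s = batch_timeout_s
        self._queue: asyncio.Queue = asyncio.Queue()
        self._loop_task: asyncio.Task | None = None

    def start(self) -> None:
        self._loop_task = asyncio.create_task(self._run())

    async def stop(self) -> None:
        if self._loop_task is not None:
            self._loop_task.cancel()
            try:
                await self._loop_task
            except asyncio.CancelledError:
                pass

    async def submit(self, item):
        fut = asyncio.get_running_loop().create_future()
        await self._queue.put((item, fut))
        return await fut  # resolved by _run once this item's batch finishes

    async def _run(self) -> None:
        loop = asyncio.get_running_loop()
        while True:
            batch = [await self._queue.get()]  # block until work arrives
            deadline = loop.time() + self._batch_timeout_s
            while len(batch) < self._max_batch_size:
                remaining = deadline - loop.time()
                if remaining <= 0:
                    break
                try:
                    batch.append(await asyncio.wait_for(self._queue.get(), remaining))
                except asyncio.TimeoutError:
                    break
            items = [item for item, _ in batch]
            try:
                results = await self._process_fn(items)
            except Exception as exc:  # propagate the failure to every caller
                for _, fut in batch:
                    if not fut.done():
                        fut.set_exception(exc)
                continue
            for (_, fut), result in zip(batch, results):
                if not fut.done():
                    fut.set_result(result)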

# ─────────────────────────────────────────────────────────────────────────────
# Helpers
# ─────────────────────────────────────────────────────────────────────────────

def _batches(items: list, batch_size: int):
    """Yield successive fixed-size batches from a list."""
    for start in range(0, len(items), batch_size):
        yield items[start : start + batch_size]

def _load_image_sync(f: File) -> PILImage.Image:
    """Blocking download + decode. Intended to be called from a thread pool."""
    with f.open_sync("rb") as fh:
        data = fh.read()
    return PILImage.open(BytesIO(data)).convert("RGB")

async def _load_image(f: File) -> PILImage.Image:
    """Download and decode a page image in a thread-pool worker.

    asyncio.to_thread runs _load_image_sync in an OS thread, so blocking
    network I/O does not stall the event loop. (The indexing tasks prefetch
    the next batch with loop.run_in_executor instead, overlapping downloads
    with the GPU forward pass.)
    """
    return await asyncio.to_thread(_load_image_sync, f)

async def _load_npz(index_file: File) -> np.lib.npyio.NpzFile:
    """Download an index File to a local temp path and open with np.load."""
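    with tempfile.NamedTemporaryFile(suffix=".npz", delete=False) as tmp:
        async with index_file.open("rb") as fh:
            tmp.write(bytes(await fh.read()))
    # np.load only after the with-block has closed (and flushed) the file;
    # loading while it is still open can read a truncated archive.
    return np.load(tmp.name)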

def _dcg(relevances: list[int]) -> float:
    return sum(rel / math.log2(rank + 2) for rank, rel in enumerate(relevances))
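
# Illustrative sketch (hypothetical helper, not part of the pipeline): the three
# metrics for ONE query with exactly one relevant page, as in the PageQuery
# schema. With a single relevant page the ideal DCG is 1 / log2(2) = 1, so
# NDCG@K reduces to the discount at the hit's 0-based rank, 1 / log2(rank + 2),
# or 0 if the page is outside the top K. The search tasks already truncate
# ranked_page_ids to top_k, so MRR computed from their output is MRR@K.
# Example: hit at 0-based rank 2 gives NDCG = 1 / log2(4) = 0.5 and MRR = 1/3.
import math


def _example_single_query_metrics(ranked: list[str], relevant: str, k: int) -> dict[str, float]:
    top_k = ranked[:k]
    if relevant not in top_k:
        return {"recall_at_k": 0.0, "ndcg_at_k": 0.0, "mrr": 0.0}
    rank = top_k.index(relevant)  # 0-based position of the hit
    return {
        "recall_at_k": 1.0,
        "ndcg_at_k": 1.0 / math.log2(rank + 2),
        "mrr": 1.0 / (rank + 1),
    }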

# ─────────────────────────────────────────────────────────────────────────────
# Tasks — data loading
# ─────────────────────────────────────────────────────────────────────────────

@driver.task(cache="auto", retries=3)
async def load_vidore_pages(subset: str = "docvqa", max_pages: int = 200) -> PageDataset:
    """
    Load a ViDoRe benchmark subset and store page images in Flyte's blob store.

    Supports two dataset formats:

    Legacy (subsampled) — single 'test' split with one row per (query, page)
    pair; fields: image, query, image_filename. streaming=True reads only the
    rows requested via islice — no full-shard download.
    Datasets: vidore/docvqa_test_subsampled, vidore/infovqa_test_subsampled

    V3 — separate corpus / queries / qrels configs following the BEIR retrieval
    benchmark format. corpus contains page images; queries contains question
    text; qrels maps query IDs to relevant corpus page IDs (many-to-many).
    Datasets: vidore/vidore_v3_finance_en  (~2 942 pages, 1 854 queries)

    The first call uploads page images to Flyte's blob store and caches the
    PageDataset; every subsequent call with the same arguments returns the
    cached result instantly. retries=3 guards against transient HuggingFace
    network failures.

    Available subsets: "docvqa", "infovqa", "vidore_v3_finance_en"
    """
    from datasets import load_dataset

    subset_map = {
        "docvqa": "vidore/docvqa_test_subsampled",
        "infovqa": "vidore/infovqa_test_subsampled",
        "vidore_v3_finance_en": "vidore/vidore_v3_finance_en",
    }
    dataset_name = subset_map.get(subset, f"vidore/{subset}_test_subsampled")

    # V3 datasets ship with separate corpus / queries / qrels configs (each with a 'test' split).
    _V3_SUBSETS = {"vidore_v3_finance_en"}

    if subset in _V3_SUBSETS:
        # ── V3 format ─────────────────────────────────────────────────────────
        # corpus / queries / qrels are HuggingFace configs (name=), not splits.
        # corpus uses streaming=True so images are decoded one at a time —
        # loading all 2 942 rows eagerly would hold gigabytes of PIL images in
        # the driver's RAM simultaneously. qrels and queries are text-only and
        # small enough to load fully into memory.
        corpus_ds = load_dataset(dataset_name, name="corpus", split="test", streaming=True)
        qrels_ds = load_dataset(dataset_name, name="qrels", split="test")
        queries_ds = load_dataset(dataset_name, name="queries", split="test")

        # Normalise field names — V3 follows BEIR convention (hyphenated ids).
        def _col(ds, *candidates):
            cols = set(ds.column_names)
            for c in candidates:
                if c in cols:
                    return c
            raise KeyError(f"None of {candidates} found in columns {cols}")

        corpus_id_col = _col(corpus_ds, "corpus-id", "corpus_id", "id", "_id")
        query_id_col = _col(queries_ds, "query-id", "query_id", "id", "_id")
        query_text_col = _col(queries_ds, "query", "text")
        qrel_qid_col = _col(qrels_ds, "query-id", "query_id")
        qrel_cid_col = _col(qrels_ds, "corpus-id", "corpus_id")

        # Slice corpus to max_pages, upload each image to Flyte blob store.
        page_ids: list[str] = []
        page_files: list[File] = []
        corpus_id_to_page_id: dict[str, str] = {}

        for i, row in enumerate(islice(corpus_ds, max_pages)):
            img = row.get("image")
            if not isinstance(img, PILImage.Image):
                continue
            cid = str(row[corpus_id_col])
            page_id = f"{subset}_{i:04d}"
            with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f:
                tmp_path = f.name
                img.convert("RGB").save(tmp_path, format="JPEG")
            del img  # free PIL memory before upload
            page_file = await File.from_local(tmp_path)
            os.unlink(tmp_path)
            corpus_id_to_page_id[cid] = page_id
            page_ids.append(page_id)
            page_files.append(page_file)

        # Build query_id → relevant page_id from qrels (first match wins).
        # Only keep relevance judgements whose corpus page is in our slice.
        qrel_map: dict[str, str] = {}
        for row in qrels_ds:
            qid = str(row[qrel_qid_col])
            cid = str(row[qrel_cid_col])
            if cid in corpus_id_to_page_id and qid not in qrel_map:
                qrel_map[qid] = corpus_id_to_page_id[cid]

        # Collect queries that have at least one relevant page in our slice.
        queries: list[PageQuery] = []
        for row in queries_ds:
            qid = str(row[query_id_col])
            if qid not in qrel_map:
                continue
            queries.append(
                PageQuery(
                    query_id=qid,
                    text=str(row[query_text_col]),
                    relevant_page_id=qrel_map[qid],
                )
            )

    else:
        # ── Legacy format ─────────────────────────────────────────────────────
        # Single 'test' split with one row per (query, page) pair.
        ds = load_dataset(dataset_name, split="test", streaming=True)

        page_ids = []
        page_files = []
        queries = []
        seen_pages: dict[str, str] = {}  # image_filename → page_id

        for i, row in enumerate(islice(ds, max_pages)):
            img = row.get("image")
            if not isinstance(img, PILImage.Image):
                continue
            filename: str = row.get("image_filename") or f"page_{i}"
            query_text: str = row.get("query", "")
            if not query_text:
                continue

            # Each unique page is uploaded exactly once; multiple queries may
            # share the same page (same image_filename).
            if filename not in seen_pages:
                page_id = f"{subset}_{len(page_ids):04d}"
                with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f:
                    tmp_path = f.name
                    img.convert("RGB").save(tmp_path, format="JPEG")
                del img  # free PIL memory before upload
                page_file = await File.from_local(tmp_path)
                os.unlink(tmp_path)
                seen_pages[filename] = page_id
                page_ids.append(page_id)
                page_files.append(page_file)
            else:
                page_id = seen_pages[filename]

            queries.append(
                PageQuery(
                    query_id=f"q{i:04d}",
                    text=query_text,
                    relevant_page_id=page_id,
                )
            )

    print(f"Loaded {len(page_ids)} unique pages, {len(queries)} queries", flush=True)
    return PageDataset(page_ids=page_ids, page_files=page_files, queries=queries)

# ─────────────────────────────────────────────────────────────────────────────
# Tasks — indexing
# ─────────────────────────────────────────────────────────────────────────────

@colpali_indexer.task(cache="auto", retries=2)
async def index_colpali(page_ids: list[str], page_files: list[File]) -> File:
    """
    Encode every page with ColPali-v1.2 and save the multi-vector index.

    ColPali skips OCR entirely. It feeds the raw page image into PaliGemma
    (a vision-language model) and produces one embedding vector per image
    patch — roughly 1,024 patches per page, each of dimension 128.

    _colpali_model() is an lru_cache'd loader. On a cold container, it
    downloads and loads the model once. On a warm container (kept alive by
    ReusePolicy), it returns the already-loaded model instantly from cache —
    no repeated ~7 GB download.

    The index is stored as a .npz file in Flyte's blob store:
      embeddings — float32, shape (n_pages, n_patches, dim)
      page_ids   — matching page ID strings

    cache="auto" + retries=2: the result is stored permanently on success;
    transient failures (e.g. HuggingFace rate limits) are retried twice.
    """
    import torch

    model, processor, device = _colpali_model()

    loop = asyncio.get_running_loop()
    batches = list(_batches(page_files, 4))
    n_batches = len(batches)

    # Submit the first batch to the thread pool before entering the loop so
    # that downloads are already in flight when we first await them.
    prefetch = [loop.run_in_executor(None, _load_image_sync, f) for f in batches[0]]

    all_embeddings: list[np.ndarray] = []
    for batch_idx in range(n_batches):
        images = list(await asyncio.gather(*prefetch))

        # Submit next batch downloads immediately — OS threads run these in
        # parallel with the GPU forward pass below.
        if batch_idx + 1 < n_batches:
            prefetch = [loop.run_in_executor(None, _load_image_sync, f) for f in batches[batch_idx + 1]]

        inputs = processor.process_images(images)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            emb = model(**inputs)  # (batch, n_patches, dim)

        all_embeddings.append(emb.cpu().float().numpy())
        print(f"ColPali: indexed batch {batch_idx + 1}/{n_batches}", flush=True)

    embeddings = np.concatenate(all_embeddings, axis=0)  # (n_pages, n_patches, dim)
    out_path = os.path.join(tempfile.gettempdir(), "colpali_index.npz")
    np.savez(out_path, embeddings=embeddings, page_ids=np.array(page_ids))
    return await File.from_local(out_path)

@siglip_indexer.task(cache="auto", retries=2)
async def index_siglip(page_ids: list[str], page_files: list[File]) -> File:
    """
    Encode every page with SigLIP SO400M and save the single-vector index.

    SigLIP (2023) is Google's successor to CLIP, trained with sigmoid loss
    instead of softmax — avoiding the normalisation bottleneck that limits
    CLIP's scalability. Produces one global embedding per page.

    _siglip_model() caches the model across warm container reuses.

    The index is stored as a .npz file:
      embeddings — float32, shape (n_pages, dim), L2-normalised
      page_ids   — matching page ID strings
    """
    import torch

    model, processor, device = _siglip_model()

    loop = asyncio.get_running_loop()
    batches = list(_batches(page_files, 8))
    n_batches = len(batches)

    # Submit the first batch to the thread pool before entering the loop so
    # that downloads are already in flight when we first await them.
    prefetch = [loop.run_in_executor(None, _load_image_sync, f) for f in batches[0]]

    all_embeddings: list[np.ndarray] = []
    for batch_idx in range(n_batches):
        images = list(await asyncio.gather(*prefetch))

        # Submit next batch downloads immediately — OS threads run these in
        # parallel with the GPU forward pass below.
        if batch_idx + 1 < n_batches:
            prefetch = [loop.run_in_executor(None, _load_image_sync, f) for f in batches[batch_idx + 1]]

        inputs = processor(images=images, return_tensors="pt", padding=True).to(device)

        with torch.no_grad():
            outputs = model.vision_model(**inputs)
            emb = outputs.pooler_output  # (batch, dim)
            emb = emb / emb.norm(dim=-1, keepdim=True)  # L2 normalise

        all_embeddings.append(emb.cpu().float().numpy())
        print(f"SigLIP: indexed batch {batch_idx + 1}/{n_batches}", flush=True)

    embeddings = np.concatenate(all_embeddings, axis=0)  # (n_pages, dim)
    out_path = os.path.join(tempfile.gettempdir(), "siglip_index.npz")
    np.savez(out_path, embeddings=embeddings, page_ids=np.array(page_ids))
    return await File.from_local(out_path)

@ocr_engine.task(cache="auto")
async def extract_page_texts(page_files: list[File]) -> list[str]:
    """
    OCR every page with doctr on GPU to produce a text-only baseline.

    doctr bundles DBNet (detection) + CRNN/SAR (recognition) into a single
    callable predictor. Pages are processed in batches of ocr_batch_size:
    each batch is downloaded in parallel, then passed to the predictor.
    asyncio.to_thread keeps the event loop unblocked during GPU inference.

    Result structure: result.pages[i].blocks[j].lines[k].words[l].value

    Cached: the same corpus is OCR'd at most once across all experiments
    that use the OCR+BM25 backend.
    """
    import gc

    predictor = _ocr_model()

    # Process in batches: download each batch just-in-time so only
    # ocr_batch_size images are in memory at once instead of all 2 000.
    ocr_batch_size = 8
    total = len(page_files)
    texts: list[str] = []
    for start in range(0, total, ocr_batch_size):
        batch_files = page_files[start : start + ocr_batch_size]
        batch_images = list(
            await asyncio.gather(*[asyncio.to_thread(_load_image_sync, f) for f in batch_files])
        )
        batch_np = [np.array(img) for img in batch_images]
        del batch_images
        result = await asyncio.to_thread(predictor, batch_np)
        del batch_np
        for page_output in result.pages:
            texts.append(
                "\n".join(
                    " ".join(word.value for word in line.words)
                    for block in page_output.blocks
                    for line in block.lines
                )
            )
        del result
        gc.collect()
        print(f"OCR: processed {min(start + ocr_batch_size, total)}/{total} pages", flush=True)

    return texts

# ─────────────────────────────────────────────────────────────────────────────
# Tasks — search
# ─────────────────────────────────────────────────────────────────────────────

# {{docs-fragment search_colpali}}
@colpali_indexer.task
async def search_colpali(
    index_file: File,
    queries: list[PageQuery],
    top_k: int,
) -> list[RetrievalResult]:
    """
    Retrieve pages using ColPali MaxSim late interaction via DynamicBatcher.

    MaxSim score for page p given query q:
        score(q, p) = Σ_{t ∈ query tokens} max_{j ∈ page patches} (q_t · p_j)

    Each query is submitted to the process-level DynamicBatcher, which
    aggregates queries from all concurrent search_colpali invocations on the
    same warm container (concurrency=8) into a single GPU batch. This keeps
    the GPU saturated rather than running one small batch per caller.

    The batcher's process_fn runs GPU work in asyncio.to_thread, so the
    aggregation loop stays live while the GPU encodes and scores.
    """
    batcher = await _get_colpali_search_batcher(index_file)
    futures = await batcher.submit_batch(queries)
    all_ranked: list[list[str]] = list(await asyncio.gather(*futures))

    return [
        RetrievalResult(query_id=q.query_id, ranked_page_ids=ranked[:top_k])
        for q, ranked in zip(queries, all_ranked)
    ]
# {{/docs-fragment search_colpali}}
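
# Illustrative sketch (hypothetical, not used by the pipeline): a NumPy reference
# for the MaxSim formula in the search_colpali docstring. The pipeline itself
# scores on GPU with torch.einsum inside the batcher.
#   query_emb:  (T, D)     one vector per query token
#   corpus_emb: (P, J, D)  J patch vectors for each of P pages
# For every page, take each query token's best dot product over that page's
# patches, then sum over tokens. Returns one score per page, shape (P,).
import numpy as np


def _example_maxsim_scores(query_emb: np.ndarray, corpus_emb: np.ndarray) -> np.ndarray:
    sims = np.einsum("td,pjd->ptj", query_emb, corpus_emb)  # (P, T, J) token-patch dots
    return sims.max(axis=2).sum(axis=1)  # max over patches, then sum over tokens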

@siglip_indexer.task
async def search_siglip(
    index_file: File,
    queries: list[PageQuery],
    top_k: int,
) -> list[RetrievalResult]:
    """
    Retrieve pages using SigLIP cosine similarity via DynamicBatcher.

    Each query is submitted to the process-level DynamicBatcher, which
    aggregates queries from all concurrent search_siglip invocations on the
    same warm container (concurrency=8) into a single GPU batch.

    SigLIP's single-vector embeddings make full vectorisation safe —
    the scores matrix (n_pages x n_queries) is small enough to materialise
    in one GPU call regardless of batch size.
    """
    batcher = await _get_siglip_search_batcher(index_file)
    futures = await batcher.submit_batch(queries)
    all_ranked: list[list[str]] = list(await asyncio.gather(*futures))

    return [
        RetrievalResult(query_id=q.query_id, ranked_page_ids=ranked[:top_k])
        for q, ranked in zip(queries, all_ranked)
    ]

@driver.task
async def search_bm25(
    page_texts: list[str],
    page_ids: list[str],
    queries: list[PageQuery],
    top_k: int,
) -> list[RetrievalResult]:
    """
    Retrieve pages using BM25 over OCR'd text.

    The standard keyword-based baseline. No GPU required; strong on
    text-dense pages, weak on visual content that the doctr OCR step cannot read.
    """
    tokenized = [text.lower().split() for text in page_texts]
    bm25 = BM25Okapi(tokenized)

    results: list[RetrievalResult] = []
    for q in queries:
        scores = bm25.get_scores(q.text.lower().split())
        ranked = sorted(range(len(page_ids)), key=lambda i: -scores[i])[:top_k]
        results.append(
            RetrievalResult(
                query_id=q.query_id,
                ranked_page_ids=[page_ids[i] for i in ranked],
            )
        )
    return results

# ─────────────────────────────────────────────────────────────────────────────
# Tasks — evaluation
# ─────────────────────────────────────────────────────────────────────────────

@driver.task
async def evaluate(
    results: list[RetrievalResult],
    ground_truth: list[PageQuery],
    k: int,
) -> Metrics:
    """
    Compute Recall@K, NDCG@K, and MRR for a single retrieval model.

    Recall@K  — was the correct page in the top-K results?
    NDCG@K    — normalised discounted cumulative gain; rewards earlier hits.
    MRR       — mean reciprocal rank of the first correct result. The search
                tasks truncate rankings to top_k, so this is effectively MRR@K.

    All three are averaged over all queries. Higher is better.
    """
    gt_map = {q.query_id: q.relevant_page_id for q in ground_truth}
    recall_vals, ndcg_vals, mrr_vals = [], [], []

    for r in results:
        relevant = gt_map.get(r.query_id, "")
        top = r.ranked_page_ids[:k]

        recall_vals.append(1.0 if relevant in top else 0.0)

        rels = [1 if pid == relevant else 0 for pid in top]
        idcg = _dcg([1])  # ideal: correct page at rank 1
        ndcg_vals.append(_dcg(rels) / idcg if idcg > 0 else 0.0)

        rr = 0.0
        for rank, pid in enumerate(r.ranked_page_ids, start=1):
            if pid == relevant:
                rr = 1.0 / rank
                break
        mrr_vals.append(rr)

    return Metrics(
        recall_at_k=float(np.mean(recall_vals)),
        ndcg_at_k=float(np.mean(ndcg_vals)),
        mrr=float(np.mean(mrr_vals)),
        k=k,
    )

# ─────────────────────────────────────────────────────────────────────────────
# Tasks — report
# ─────────────────────────────────────────────────────────────────────────────

@driver.task(report=True)
async def generate_report(report: ComparisonReport) -> None:
    """
    Emit an interactive HTML report visible in the Flyte UI.

    report=True marks this task as a reporting task. Flyte renders the HTML
    written via flyte.report.replace.aio() directly in the execution detail
    page — no separate dashboard or export step required.

    The report contains:
      - Summary cards: experiment count, best model, best Recall@K.
      - Grouped bar chart: Recall@K, NDCG@K, MRR side-by-side per experiment.
      - Ranked results table with all three metrics.
    """
    sorted_results = sorted(report.results, key=lambda r: -r.metrics.recall_at_k)
    best = sorted_results[0]

    labels = [r.config.name for r in sorted_results]
    recall_vals = [r.metrics.recall_at_k for r in sorted_results]
    ndcg_vals = [r.metrics.ndcg_at_k for r in sorted_results]
    mrr_vals = [r.metrics.mrr for r in sorted_results]

    table_rows = "".join(
        f"""
        <tr>
          <td>{r.config.name}</td>
          <td>{r.config.model.value}</td>
          <td>{r.metrics.recall_at_k:.3f}</td>
          <td>{r.metrics.ndcg_at_k:.3f}</td>
          <td>{r.metrics.mrr:.3f}</td>
          <td>{r.metrics.k}</td>
        </tr>"""
        for r in sorted_results
    )

    html = f"""<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <title>Visual Document Retrieval — Results</title>
  <script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
  <style>
    * {{ box-sizing: border-box; margin: 0; padding: 0; }}
    body {{
      font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;
      background: #f0f2f5; color: #222; padding: 24px;
    }}
    h1 {{ font-size: 1.6em; margin-bottom: 4px; }}
    .subtitle {{ color: #666; margin-bottom: 24px; font-size: 0.95em; }}
    .cards {{
      display: flex; gap: 16px; flex-wrap: wrap; margin-bottom: 28px;
    }}
    .card {{
      background: #fff; border-radius: 10px; padding: 18px 24px;
      box-shadow: 0 1px 4px rgba(0,0,0,.08); min-width: 160px;
    }}
    .card-value {{ font-size: 1.9em; font-weight: 700; color: #4f46e5; }}
    .card-label {{ font-size: 0.8em; color: #888; text-transform: uppercase;
                   letter-spacing: .04em; margin-top: 2px; }}
    .chart-box {{
      background: #fff; border-radius: 10px; padding: 24px;
      box-shadow: 0 1px 4px rgba(0,0,0,.08); margin-bottom: 28px;
    }}
    .chart-box h2 {{ font-size: 1em; margin-bottom: 16px; color: #444; }}
    table {{ width: 100%; border-collapse: collapse; font-size: 0.9em; }}
    th {{
      background: #4f46e5; color: #fff; padding: 10px 14px;
      text-align: left; font-weight: 600;
    }}
    td {{ padding: 9px 14px; border-bottom: 1px solid #eee; }}
    tr:hover td {{ background: #f8f8ff; }}
    tr:first-child td {{ font-weight: 600; }}
  </style>
</head>
<body>
  <h1>Visual Document Retrieval — Experiment Comparison</h1>
  <p class="subtitle">ViDoRe benchmark &middot; {len(report.results)} experiment(s)</p>

  <div class="cards">
    <div class="card">
      <div class="card-value">{len(report.results)}</div>
      <div class="card-label">Experiments</div>
    </div>
    <div class="card">
      <div class="card-value">{best.config.name}</div>
      <div class="card-label">Best by Recall@K</div>
    </div>
    <div class="card">
      <div class="card-value">{best.metrics.recall_at_k:.3f}</div>
      <div class="card-label">Best Recall@{best.metrics.k}</div>
    </div>
    <div class="card">
      <div class="card-value">{best.metrics.ndcg_at_k:.3f}</div>
      <div class="card-label">Best NDCG@{best.metrics.k}</div>
    </div>
    <div class="card">
      <div class="card-value">{best.metrics.mrr:.3f}</div>
      <div class="card-label">Best MRR</div>
    </div>
  </div>

  <div class="chart-box">
    <h2>Metrics by Experiment</h2>
    <canvas id="metricsChart" height="100"></canvas>
  </div>

  <div class="chart-box">
    <h2>Ranked Results</h2>
    <table>
      <thead>
        <tr>
          <th>Experiment</th><th>Model</th>
          <th>Recall@K</th><th>NDCG@K</th><th>MRR</th><th>K</th>
        </tr>
      </thead>
      <tbody>{table_rows}</tbody>
    </table>
  </div>

  <script>
    new Chart(document.getElementById('metricsChart'), {{
      type: 'bar',
      data: {{
        labels: {json.dumps(labels)},
        datasets: [
          {{
            label: 'Recall@K',
            data: {json.dumps(recall_vals)},
            backgroundColor: 'rgba(79,70,229,0.75)',
            borderRadius: 4
          }},
          {{
            label: 'NDCG@K',
            data: {json.dumps(ndcg_vals)},
            backgroundColor: 'rgba(16,185,129,0.75)',
            borderRadius: 4
          }},
          {{
            label: 'MRR',
            data: {json.dumps(mrr_vals)},
            backgroundColor: 'rgba(245,158,11,0.75)',
            borderRadius: 4
          }}
        ]
      }},
      options: {{
        responsive: true,
        plugins: {{ legend: {{ position: 'top' }} }},
        scales: {{
          y: {{ beginAtZero: true, max: 1.0,
               title: {{ display: true, text: 'Score' }} }}
        }}
      }}
    }});
  </script>
</body>
</html>"""

    await flyte.report.replace.aio(html)
    await flyte.report.flush.aio()

# ─────────────────────────────────────────────────────────────────────────────
# Experiment orchestration
# ─────────────────────────────────────────────────────────────────────────────

# {{docs-fragment run_experiment}}
@driver.task
async def run_experiment(config: ExperimentConfig, dataset: PageDataset) -> ExperimentResult:
    """
    End-to-end retrieval pipeline for a single ExperimentConfig.

    Flyte v2's dynamic execution means this driver task can call GPU tasks
    (index_colpali, search_colpali) based on the runtime value of config.model
    — no static DAG wiring required. The if/elif is plain Python; Flyte
    schedules the selected sub-tasks on the appropriate environment.

    Caching: two experiments that share the same model and corpus (e.g. ColPali
    at top_k=5 and top_k=10) will hit the same cached index. GPU work is paid
    at most once per (model, corpus) pair across all experiments.

    Search queries are sharded into chunks of SEARCH_SHARD_SIZE and dispatched
    as concurrent task invocations. All shards land on the single warm container
    (replicas=1) and feed the same DynamicBatcher simultaneously, keeping the
    GPU saturated throughout search rather than processing one large sequential
    batch from a single caller.

    flyte.group wraps each experiment in a named span in the Flyte UI, making
    it easy to compare latencies and drill into individual runs.
    """
    SEARCH_SHARD_SIZE = 256

    with flyte.group(config.name):
        if config.model == RetrievalModel.COLPALI:
            index_file = await index_colpali(dataset.page_ids, dataset.page_files)
            shards = list(_batches(dataset.queries, SEARCH_SHARD_SIZE))
            shard_results = await asyncio.gather(
                *[search_colpali(index_file, shard, config.top_k) for shard in shards]
            )
            results = [r for shard in shard_results for r in shard]

        elif config.model == RetrievalModel.SIGLIP:
            index_file = await index_siglip(dataset.page_ids, dataset.page_files)
            shards = list(_batches(dataset.queries, SEARCH_SHARD_SIZE))
            shard_results = await asyncio.gather(
                *[search_siglip(index_file, shard, config.top_k) for shard in shards]
            )
            results = [r for shard in shard_results for r in shard]

        else:  # RetrievalModel.OCR_BM25
            page_texts = await extract_page_texts(dataset.page_files)
            results = await search_bm25(page_texts, dataset.page_ids, dataset.queries, config.top_k)

        metrics = await evaluate(results, dataset.queries, config.top_k)

    return ExperimentResult(config=config, metrics=metrics)
# {{/docs-fragment run_experiment}}

# {{docs-fragment compare_experiments}}
@driver.task
async def compare_experiments(
    configs: list[ExperimentConfig],
    subset: str = "docvqa",
    max_pages: int = 200,
) -> ComparisonReport:
    """
    Fan out over all experiment configs and return a ranked comparison table.

    The dataset is loaded once and shared across all experiments. Each config
    runs as a concurrent Flyte task via asyncio.gather. Experiments that share
    a model reuse the cached index — you only pay GPU time for new work.

    On completion, generate_report emits an interactive Chart.js HTML report
    visible directly in the Flyte execution detail page.

    The function defaults (subset="docvqa", max_pages=200) give a quick run.
    The entry point below instead uses vidore_v3_finance_en (~2,942 corpus
    pages, 1,854 queries) with max_pages=2000 to exercise the GPU pipeline at scale.
    """
    dataset = await load_vidore_pages(subset=subset, max_pages=max_pages)

    # All experiments launch concurrently. Shared cached outputs (same model,
    # same corpus) are served from cache rather than recomputed.
    experiment_coros = [run_experiment(config=cfg, dataset=dataset) for cfg in configs]
    results: list[ExperimentResult] = list(await asyncio.gather(*experiment_coros))

    report = ComparisonReport(results=results)
    print(report.summary())
    best = report.best_by("recall_at_k")
    print(f"\nBest by Recall@{best.metrics.k}: {best.config.name}")

    # Emit the interactive HTML report in the Flyte UI.
    await generate_report(report)

    return report
# {{/docs-fragment compare_experiments}}

# ─────────────────────────────────────────────────────────────────────────────
# Entry point
# ─────────────────────────────────────────────────────────────────────────────

if __name__ == "__main__":
    flyte.init_from_config()

    # Define the experiment grid. Each ExperimentConfig is one point in the
    # design space. Adding a new model or varying top_k is one line here —
    # no task code changes required.
    #
    # ColPali appears twice with different top_k values. The cache ensures
    # index_colpali runs only once and both experiments share that result.
    # {{docs-fragment grid}}
    configs = [
        ExperimentConfig(name="colpali-top5", model=RetrievalModel.COLPALI, top_k=5),
        ExperimentConfig(name="colpali-top10", model=RetrievalModel.COLPALI, top_k=10),
        ExperimentConfig(name="siglip-top5", model=RetrievalModel.SIGLIP, top_k=5),
        ExperimentConfig(name="ocr-bm25-top5", model=RetrievalModel.OCR_BM25, top_k=5),
    ]
    # {{/docs-fragment grid}}

    run = flyte.with_runcontext().run(
        compare_experiments,
        configs=configs,
        subset="vidore_v3_finance_en",
        max_pages=2000,
    )
    print(f"Run URL: {run.url}")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/multimodal-retrieval-evaluation/retrieval_eval.py*
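
With exactly one relevant page per query, the three metrics that `evaluate` computes depend only on the rank at which the correct page appears. The sketch below restates that logic in plain Python for a single query; `single_query_metrics`, `dcg`, and the page IDs are hypothetical names for illustration, and it assumes the ranked list is already truncated to `top_k`, as the search tasks do.

```python
import math


def dcg(relevances: list[int]) -> float:
    """DCG with a log2(rank + 1) discount over 1-based ranks."""
    return sum(rel / math.log2(rank + 1) for rank, rel in enumerate(relevances, start=1))


def single_query_metrics(
    ranked_page_ids: list[str], relevant: str, k: int
) -> tuple[float, float, float]:
    """Return (recall@k, ndcg@k, reciprocal rank) for one query with one relevant page."""
    top = ranked_page_ids[:k]
    # Recall@K is a hit/miss indicator when there is a single relevant page.
    recall = 1.0 if relevant in top else 0.0
    # Ideal DCG puts the single relevant page at rank 1: 1 / log2(2) = 1.0.
    ndcg = dcg([1 if pid == relevant else 0 for pid in top]) / dcg([1])
    # Reciprocal rank of the first (and only) correct page, 0.0 if absent.
    rr = next(
        (1.0 / rank for rank, pid in enumerate(ranked_page_ids, start=1) if pid == relevant),
        0.0,
    )
    return recall, ndcg, rr


# Toy example: the relevant page "p7" is retrieved at rank 3.
recall, ndcg, rr = single_query_metrics(["p2", "p9", "p7", "p1", "p4"], relevant="p7", k=5)
print(f"recall={recall:.3f} ndcg={ndcg:.3f} rr={rr:.3f}")  # prints recall=1.000 ndcg=0.500 rr=0.333
```

With the correct page at rank 3, NDCG@5 is 1 / log2(4) = 0.5 while the reciprocal rank is 1/3, so NDCG penalises a late hit less steeply than MRR. Averaging the per-query values over all queries gives the fields of `Metrics`.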

## Loading, indexing, and search

`load_vidore_pages` downloads a ViDoRe subset and uploads each page image to blob storage (cached, with retries). The indexing tasks (`index_colpali`, `index_siglip`) encode every page into a `.npz` index, and the OCR task (`extract_page_texts`) extracts page text for the BM25 baseline. The indexing and OCR tasks run on the GPU environments and are cached per corpus.
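
The search tasks read an index through `_load_npz` and expect two arrays, `embeddings` and `page_ids`. The indexing code is not shown in this section, so the writer below is an assumption that only reproduces those two keys; the sizes and random data are hypothetical, and it uses an in-memory buffer instead of a temp file. It round-trips a tiny ColPali-style multi-vector index through `np.savez` and checks that the vectorised MaxSim score (einsum, max over page patches, sum over query tokens, as in the ColPali batcher) matches the per-token definition `score(q, p) = Σ_t max_j (q_t · p_j)`.

```python
import io

import numpy as np

rng = np.random.default_rng(0)
n_pages, n_patches, n_tokens, dim = 4, 6, 3, 8

# Hypothetical multi-vector index: one embedding per page patch.
embeddings = rng.standard_normal((n_pages, n_patches, dim)).astype(np.float32)
page_ids = np.array([f"page-{i}" for i in range(n_pages)])

# Write and read back an index with the keys that _load_npz consumers expect.
buf = io.BytesIO()
np.savez(buf, embeddings=embeddings, page_ids=page_ids)
buf.seek(0)
with np.load(buf) as data:
    corpus = data["embeddings"]  # (n_pages, n_patches, dim)
    ids = list(data["page_ids"])

query = rng.standard_normal((n_tokens, dim)).astype(np.float32)  # (T, D)

# Vectorised MaxSim: dot every query token with every patch of every page,
# keep the best patch per token, then sum over query tokens.
scores = np.einsum("td,pjd->tpj", query, corpus).max(axis=2).sum(axis=0)  # (n_pages,)

# Reference: the same formula written as explicit loops.
reference = np.array([
    sum(max(float(query[t] @ corpus[p, j]) for j in range(n_patches)) for t in range(n_tokens))
    for p in range(n_pages)
])
assert np.allclose(scores, reference, atol=1e-4)

# Rank pages best-first, as the search batcher does with argsort(descending=True).
ranked = [ids[i] for i in np.argsort(-scores)]
print(ranked)
```

The GPU version in the tutorial adds a batch dimension (`"ctd,pjd->ctpj"`) and processes queries in chunks of 8 so the intermediate similarity tensor stays bounded in memory.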

Search uses the `DynamicBatcher` so queries from all concurrent search-task invocations on a warm container are merged into a single GPU batch:

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "colpali-engine>=0.3.1",
#     "transformers>=4.41",
#     "sentencepiece>=0.2",
#     "torch>=2.0",
#     "pillow>=10",
#     "datasets>=2.18",
#     "rank-bm25>=0.2",
#     "numpy>=1.26",
#     "python-doctr[torch]>=0.8",
#     "pydantic>=2.0",
#     "flyte>=2.0.0",
# ]
# ///
"""
Multimodal Retrieval Evaluation Pipeline

This tutorial is an experiment framework for benchmarking visual document
retrieval approaches on the ViDoRe benchmark. Each experiment is defined by
an ExperimentConfig; the pipeline fans them out as concurrent Flyte tasks and
returns a ranked comparison table with an interactive HTML report.

The corpus is a set of PDF page images; queries are plain-text questions. Each
retrieval method must find the page that answers each question — no text is
provided to the model, only the raw image.

  ColPali-v1.2  — patch-level multi-vector embeddings from a VLM (PaliGemma).
                  No OCR. The model produces one vector per image patch
                  (~1024 per page). MaxSim late-interaction scoring finds the
                  best matching patch for each query token.

  SigLIP-SO400M — single global embedding per page from Google's 2023 CLIP
                  successor. One matrix multiply per query; fast and effective
                  but a single vector cannot localise fine-grained regions.

  OCR + BM25    — text-only baseline. doctr (GPU OCR) extracts text in
                  batches, BM25 matches keywords. Strong on text-dense pages;
                  fails on charts, tables, and figures where content is visual.

"""

import asyncio
import enum
import json
import math
import os
import tempfile
from functools import lru_cache
from io import BytesIO
from itertools import islice

import numpy as np
from PIL import Image as PILImage
from pydantic import BaseModel
from rank_bm25 import BM25Okapi

from extras import DynamicBatcher

import flyte
import flyte.report
from flyte.io import File

# ─────────────────────────────────────────────────────────────────────────────
# Environments
# ─────────────────────────────────────────────────────────────────────────────

# One Docker image for all tasks. The PEP 723 header defines Python deps.
# ca-certificates is required for HTTPS calls to HuggingFace and blob stores.
# {{docs-fragment image}}
image = (
    flyte.Image.from_uv_script(__file__, name="vidore-eval-v2")
    .with_apt_packages("ca-certificates", "libxcb1", "libgl1", "libglib2.0-0")
    # unionai-reuse installs the unionai-actor-bridge binary required by ReusePolicy.
    # Without it every reusable container exits with StartError (exit code 128).
    .with_pip_packages("unionai-reuse>=0.1.11")
)
# {{/docs-fragment image}}

# GPU environment for ColPali image encoding and search.
#
# ReusePolicy keeps a warm GPU container alive between task calls.
# Without it, every task invocation cold-starts a new container and downloads
# ColPali-v1.2 (~7 GB) from scratch. With it, the container — and the model
# weights already loaded into VRAM — is reused for the next task dispatch.
#
#   replicas=1      single warm container — all concurrent shard calls land
#                   here so they share one DynamicBatcher process
#   concurrency=8   up to 8 query-shard tasks run simultaneously on the
#                   container, all feeding the same DynamicBatcher queue
#   idle_ttl=120    keep alive 2 min after the last task finishes
#   scaledown_ttl=60 scale to zero after 1 min of complete inactivity
# {{docs-fragment envs}}
colpali_indexer = flyte.TaskEnvironment(
    name="vidore-colpali-indexer",
    image=image,
    resources=flyte.Resources(cpu=4, memory="16Gi", gpu="A10G:1"),
    reusable=flyte.ReusePolicy(
        replicas=1,
        concurrency=8,
        idle_ttl=120,
        scaledown_ttl=60,
    ),
)

# GPU environment for SigLIP image encoding and search.
#
# Separate from the ColPali environment so each model's warm containers
# are managed independently — ColPali and SigLIP experiments can scale
# without contending for the same pool of reusable containers.
siglip_indexer = flyte.TaskEnvironment(
    name="vidore-siglip-indexer",
    image=image,
    resources=flyte.Resources(cpu=4, memory="8Gi", gpu=1),
    reusable=flyte.ReusePolicy(
        replicas=1,
        concurrency=8,
        idle_ttl=120,
        scaledown_ttl=60,
    ),
)

# GPU environment for doctr OCR. doctr runs DBNet (detection) + CRNN (recognition)
# in batches on GPU — much faster than CPU Tesseract.
# No ReusePolicy needed: the result is cached, so this task runs at most once per corpus.
ocr_engine = flyte.TaskEnvironment(
    name="vidore-ocr-engine",
    image=image,
    resources=flyte.Resources(cpu=4, memory="20Gi", gpu=1),
)

# Driver: orchestration, BM25 search, evaluation, and reporting.
# depends_on ensures the shared Docker image is built before all environments
# try to schedule tasks.
driver = flyte.TaskEnvironment(
    name="vidore-driver",
    image=image,
    resources=flyte.Resources(cpu=2, memory="12Gi"),
    depends_on=[colpali_indexer, siglip_indexer, ocr_engine],
)
# {{/docs-fragment envs}}

# ─────────────────────────────────────────────────────────────────────────────
# Configuration types
# ─────────────────────────────────────────────────────────────────────────────

# {{docs-fragment config_types}}
class RetrievalModel(str, enum.Enum):
    """Retrieval backend to evaluate."""

    COLPALI = "colpali-v1.2"  # multi-vector patch embeddings, MaxSim
    SIGLIP = "siglip-so400m"  # single-vector global embedding, cosine sim
    OCR_BM25 = "ocr+bm25"  # text extracted by doctr OCR, ranked by BM25

class ExperimentConfig(BaseModel):
    """
    All knobs for one retrieval experiment. Passed as a typed Flyte input.

    Because ExperimentConfig is a Pydantic model, Flyte serialises it
    alongside every task output — so you can always reconstruct which
    config produced which metric without maintaining a separate log.
    """

    name: str  # human-readable label shown in the comparison table
    model: RetrievalModel
    top_k: int = 5  # number of pages to retrieve per query
# {{/docs-fragment config_types}}

# ─────────────────────────────────────────────────────────────────────────────
# Data types
# ─────────────────────────────────────────────────────────────────────────────

# {{docs-fragment data_types}}
class PageQuery(BaseModel):
    """One retrieval query with its ground-truth page."""

    query_id: str
    text: str  # e.g. "What was revenue growth in Q3?"
    relevant_page_id: str  # one correct page per query

class PageDataset(BaseModel):
    """
    A corpus of document page images paired with text queries.

    page_ids:   unique page identifiers (derived from ViDoRe image filenames).
    page_files: the same pages stored in Flyte's blob store as JPEG File
                handles. Tasks read images directly from here; no live HTTP.
    queries:    text questions with ground-truth page IDs for evaluation.
    """

    page_ids: list[str]
    page_files: list[File]
    queries: list[PageQuery]

    class Config:
        arbitrary_types_allowed = True

class RetrievalResult(BaseModel):
    query_id: str
    ranked_page_ids: list[str]  # ordered best → worst

class Metrics(BaseModel):
    recall_at_k: float
    ndcg_at_k: float
    mrr: float
    k: int

class ExperimentResult(BaseModel):
    config: ExperimentConfig
    metrics: Metrics
# {{/docs-fragment data_types}}

class ComparisonReport(BaseModel):
    results: list[ExperimentResult]

    def best_by(self, metric: str = "recall_at_k") -> ExperimentResult:
        return max(self.results, key=lambda r: getattr(r.metrics, metric))

    def summary(self) -> str:
        header = f"{'Experiment':<30} {'Model':<18} {'Recall@K':>10} {'NDCG@K':>8} {'MRR':>7}"
        sep = "─" * len(header)
        rows = [header, sep]
        for r in sorted(self.results, key=lambda x: -x.metrics.recall_at_k):
            rows.append(
                f"{r.config.name:<30} "
                f"{r.config.model.value:<18} "
                f"{r.metrics.recall_at_k:>10.3f} "
                f"{r.metrics.ndcg_at_k:>8.3f} "
                f"{r.metrics.mrr:>7.3f}"
            )
        return "\n".join(rows)

# ─────────────────────────────────────────────────────────────────────────────
# Cached model loaders
# ─────────────────────────────────────────────────────────────────────────────
# These functions are at module level so they are shared across all tasks that
# run on the same warm container (via ReusePolicy). lru_cache(maxsize=1) means
# the model is loaded from disk/HuggingFace exactly once per container process
# and kept in GPU memory for every subsequent task dispatch to that container.

@lru_cache(maxsize=1)
def _colpali_model():
    """Load ColPali-v1.2 into GPU memory and cache the result.

    device_map= is the correct loading pattern for ColPali's PaliGemma
    backbone; it handles weight placement via accelerate. torch.compile is
    skipped — ColPali is GPU-compute-bound and the DynamicBatcher's cross-
    invocation batching is the primary GPU utilisation mechanism.
    """
    import torch
    from colpali_engine.models import ColPali, ColPaliProcessor

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = ColPali.from_pretrained(
        "vidore/colpali-v1.2",
        torch_dtype=torch.bfloat16,
        device_map=device,
    )
    processor = ColPaliProcessor.from_pretrained("vidore/colpali-v1.2")
    return model, processor, device

@lru_cache(maxsize=1)
def _siglip_model():
    """Load SigLIP SO400M into GPU memory, compile it, and cache the result.

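    torch.compile (mode="reduce-overhead") wraps the model's top-level forward
    pass; compilation happens lazily on the first call through that wrapper and
    is paid once per warm container lifetime. Unlike SigLIP, ColPali is not
    compiled. Submodules accessed directly on the compiled wrapper, such as
    model.text_model in the search batcher, run uncompiled.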
    """
    import torch
    from transformers import AutoModel, AutoProcessor

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModel.from_pretrained("google/siglip-so400m-patch14-224").to(device)
    if device == "cuda":
        model = torch.compile(model, mode="reduce-overhead")
    processor = AutoProcessor.from_pretrained("google/siglip-so400m-patch14-224")
    return model, processor, device

@lru_cache(maxsize=1)
def _ocr_model():
    """Load the doctr OCR predictor onto GPU and cache it.

    doctr's ocr_predictor bundles a detection model (DBNet) and a
    recognition model (CRNN/SAR) into a single callable. pretrained=True
    downloads both model weights from doctr's model zoo on first use.
    """
    import torch
    from doctr.models import ocr_predictor

    predictor = ocr_predictor(pretrained=True)
    if torch.cuda.is_available():
        predictor = predictor.cuda()
    return predictor

# ─────────────────────────────────────────────────────────────────────────────
# Search batcher singletons
# ─────────────────────────────────────────────────────────────────────────────
# One DynamicBatcher per model, shared across all concurrent search task
# invocations on the same warm container (concurrency=8). Queries from every
# concurrent caller are aggregated into a single GPU batch, maximizing
# throughput compared to each invocation running its own forward pass.
#
# Initialised lazily on the first search call via double-checked locking and
# lives for the container's lifetime. The process_fn runs GPU work via
# asyncio.to_thread so the aggregation loop can continue collecting queries
# from other callers while the GPU processes the current batch.
#
# File is not hashable so alru_cache cannot be used here; module-level state
# with asyncio.Lock is the correct pattern.
#
# Assumption: index_colpali/index_siglip use cache="auto", so the same corpus
# always produces the same index File across all callers on this container. If
# the index file ever changed between calls, the batcher would silently continue
# using the corpus embeddings loaded from the first call.

_colpali_batcher: DynamicBatcher | None = None
_colpali_batcher_lock = asyncio.Lock()
_siglip_batcher: DynamicBatcher | None = None
_siglip_batcher_lock = asyncio.Lock()

async def _get_colpali_search_batcher(index_file: File) -> DynamicBatcher:
    """Return the process-level ColPali search batcher, creating it on first call."""
    global _colpali_batcher
    if _colpali_batcher is not None:
        return _colpali_batcher
    async with _colpali_batcher_lock:
        if _colpali_batcher is not None:
            return _colpali_batcher

        import torch

        data = await _load_npz(index_file)
        corpus_emb = torch.from_numpy(data["embeddings"])  # (n_pages, n_patches, dim)
        index_page_ids: list[str] = list(data["page_ids"])
        model, processor, device = _colpali_model()
        corpus_emb = corpus_emb.to(device, dtype=torch.float32)

        async def colpali_process_fn(batch: list[PageQuery]) -> list[list[str]]:
            def _gpu_work() -> list[list[str]]:
                query_inputs = processor.process_queries([q.text for q in batch])
                query_inputs = {k: v.to(device) for k, v in query_inputs.items()}
                with torch.no_grad():
                    query_embs = model(**query_inputs).float()  # (B, T, D)
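                    # Score queries in chunks of 8 to bound the size of the
                    # (chunk, tokens, pages, patches) similarity tensor.
                    query_chunk = 8
                    n_pages = corpus_emb.shape[0]
                    all_scores = torch.empty(len(batch), n_pages, device=device)
                    for start in range(0, len(batch), query_chunk):
                        chunk = query_embs[start : start + query_chunk]
                        # MaxSim: dot every query token with every page patch,
                        # keep the best patch per token, then sum over tokens.
                        all_scores[start : start + query_chunk] = (
                            torch.einsum("ctd,pjd->ctpj", chunk, corpus_emb)
                            .max(dim=3).values
                            .sum(dim=1)
                        )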
                    sorted_indices = all_scores.argsort(dim=1, descending=True).cpu().tolist()
                return [[index_page_ids[j] for j in ranked] for ranked in sorted_indices]

            # Run GPU work in a thread so the event loop — and the batcher's
            # aggregation loop — remain unblocked while the GPU is busy.
            return await asyncio.to_thread(_gpu_work)

        batcher: DynamicBatcher[PageQuery, list[str]] = DynamicBatcher(
            process_fn=colpali_process_fn,
            target_batch_cost=128,
            max_batch_size=128,
            batch_timeout_s=0.05,
            default_cost=1,
            prefetch_batches=2,
        )
        await batcher.start()
        _colpali_batcher = batcher
    return _colpali_batcher

async def _get_siglip_search_batcher(index_file: File) -> DynamicBatcher:
    """Return the process-level SigLIP search batcher, creating it on first call."""
    global _siglip_batcher
    if _siglip_batcher is not None:
        return _siglip_batcher
    async with _siglip_batcher_lock:
        if _siglip_batcher is not None:
            return _siglip_batcher

        import torch

        data = await _load_npz(index_file)
        corpus_emb = torch.from_numpy(data["embeddings"])  # (n_pages, dim), L2-normalised
        index_page_ids: list[str] = list(data["page_ids"])
        model, processor, device = _siglip_model()
        corpus_emb = corpus_emb.to(device)

        async def siglip_process_fn(batch: list[PageQuery]) -> list[list[str]]:
            def _gpu_work() -> list[list[str]]:
                text_inputs = processor(
                    text=[q.text for q in batch],
                    return_tensors="pt",
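                    # SigLIP was trained on max_length-padded text and pools the
                    # last position, so pad every query to the same fixed length.
                    padding="max_length",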
                    truncation=True,
                ).to(device)
                with torch.no_grad():
                    text_out = model.text_model(**text_inputs)
                    query_embs = text_out.pooler_output  # (B, dim)
                    query_embs = query_embs / query_embs.norm(dim=-1, keepdim=True)
                    scores_matrix = corpus_emb @ query_embs.T  # (n_pages, B)
                    sorted_indices = scores_matrix.argsort(dim=0, descending=True).T.cpu().tolist()
                return [[index_page_ids[j] for j in ranked] for ranked in sorted_indices]

            return await asyncio.to_thread(_gpu_work)

        batcher = DynamicBatcher(
            process_fn=siglip_process_fn,
            target_batch_cost=128,
            max_batch_size=128,
            batch_timeout_s=0.05,
            default_cost=1,
            prefetch_batches=2,
        )
        await batcher.start()
        _siglip_batcher = batcher
    return _siglip_batcher

# ─────────────────────────────────────────────────────────────────────────────
# Helpers
# ─────────────────────────────────────────────────────────────────────────────

def _batches(items: list, batch_size: int):
    """Yield successive fixed-size batches from a list."""
    for start in range(0, len(items), batch_size):
        yield items[start : start + batch_size]

def _load_image_sync(f: File) -> PILImage.Image:
    """Blocking download + decode. Intended to be called from a thread pool."""
    with f.open_sync("rb") as fh:
        data = fh.read()
    return PILImage.open(BytesIO(data)).convert("RGB")

async def _load_image(f: File) -> PILImage.Image:
    """Download and decode a page image in a thread-pool worker.

    asyncio.to_thread runs _load_image_sync in an OS thread, so blocking
    network I/O does not stall the event loop and can overlap with GPU-bound
    forward passes when image loads are started before the GPU kernel runs.
    """
    return await asyncio.to_thread(_load_image_sync, f)

async def _load_npz(index_file: File) -> np.lib.npyio.NpzFile:
    """Download an index File to a local temp path and open with np.load."""
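    with tempfile.NamedTemporaryFile(suffix=".npz", delete=False) as tmp:
        async with index_file.open("rb") as fh:
            tmp.write(bytes(await fh.read()))
    # Load only after the with-block has closed (and flushed) the file;
    # delete=False keeps it on disk so np.load can reopen it by path.
    return np.load(tmp.name)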

def _dcg(relevances: list[int]) -> float:
    """Discounted cumulative gain: sum of rel / log2(rank + 2) over 0-based ranks."""
    return sum(rel / math.log2(rank + 2) for rank, rel in enumerate(relevances))

# ─────────────────────────────────────────────────────────────────────────────
# Tasks — data loading
# ─────────────────────────────────────────────────────────────────────────────

@driver.task(cache="auto", retries=3)
async def load_vidore_pages(subset: str = "docvqa", max_pages: int = 200) -> PageDataset:
    """
    Load a ViDoRe benchmark subset and store page images in Flyte's blob store.

    Supports two dataset formats:

    Legacy (subsampled) — single 'test' split with one row per (query, page)
    pair; fields: image, query, image_filename. streaming=True reads only the
    rows requested via islice — no full-shard download.
    Datasets: vidore/docvqa_test_subsampled, vidore/infovqa_test_subsampled

    V3 — separate corpus / queries / qrels splits following the BEIR retrieval
    benchmark format. corpus contains page images; queries contains question
    text; qrels maps query IDs to relevant corpus page IDs (many-to-many).
    Datasets: vidore/vidore_v3_finance_en  (~2 942 pages, 1 854 queries)

    The first call uploads page images to Flyte's blob store and caches the
    PageDataset; every subsequent call with the same arguments returns the
    cached result instantly. retries=3 guards against transient HuggingFace
    network failures.

    Available subsets: "docvqa", "infovqa", "vidore_v3_finance_en"
    """
    from datasets import load_dataset

    subset_map = {
        "docvqa": "vidore/docvqa_test_subsampled",
        "infovqa": "vidore/infovqa_test_subsampled",
        "vidore_v3_finance_en": "vidore/vidore_v3_finance_en",
    }
    dataset_name = subset_map.get(subset, f"vidore/{subset}_test_subsampled")

    # V3 datasets ship with separate corpus / queries / qrels splits.
    _V3_SUBSETS = {"vidore_v3_finance_en"}

    if subset in _V3_SUBSETS:
        # ── V3 format ─────────────────────────────────────────────────────────
        # corpus / queries / qrels are HuggingFace configs (name=), not splits.
        # corpus uses streaming=True so images are decoded one at a time —
        # loading all 2 942 rows eagerly would hold gigabytes of PIL images in
        # the driver's RAM simultaneously. qrels and queries are text-only and
        # small enough to load fully into memory.
        corpus_ds = load_dataset(dataset_name, name="corpus", split="test", streaming=True)
        qrels_ds = load_dataset(dataset_name, name="qrels", split="test")
        queries_ds = load_dataset(dataset_name, name="queries", split="test")

        # Normalise field names — V3 follows BEIR convention (hyphenated ids).
        def _col(ds, *candidates):
            cols = set(ds.column_names)
            for c in candidates:
                if c in cols:
                    return c
            raise KeyError(f"None of {candidates} found in columns {cols}")

        corpus_id_col = _col(corpus_ds, "corpus-id", "corpus_id", "id", "_id")
        query_id_col = _col(queries_ds, "query-id", "query_id", "id", "_id")
        query_text_col = _col(queries_ds, "query", "text")
        qrel_qid_col = _col(qrels_ds, "query-id", "query_id")
        qrel_cid_col = _col(qrels_ds, "corpus-id", "corpus_id")

        # Slice corpus to max_pages, upload each image to Flyte blob store.
        page_ids: list[str] = []
        page_files: list[File] = []
        corpus_id_to_page_id: dict[str, str] = {}

        for i, row in enumerate(islice(corpus_ds, max_pages)):
            img = row.get("image")
            if not isinstance(img, PILImage.Image):
                continue
            cid = str(row[corpus_id_col])
            page_id = f"{subset}_{i:04d}"
            with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f:
                tmp_path = f.name
                img.convert("RGB").save(tmp_path, format="JPEG")
            del img  # free PIL memory before upload
            page_file = await File.from_local(tmp_path)
            os.unlink(tmp_path)
            corpus_id_to_page_id[cid] = page_id
            page_ids.append(page_id)
            page_files.append(page_file)

        # Build query_id → relevant page_id from qrels (first match wins).
        # Only keep relevance judgements whose corpus page is in our slice.
        qrel_map: dict[str, str] = {}
        for row in qrels_ds:
            qid = str(row[qrel_qid_col])
            cid = str(row[qrel_cid_col])
            if cid in corpus_id_to_page_id and qid not in qrel_map:
                qrel_map[qid] = corpus_id_to_page_id[cid]

        # Collect queries that have at least one relevant page in our slice.
        queries: list[PageQuery] = []
        for row in queries_ds:
            qid = str(row[query_id_col])
            if qid not in qrel_map:
                continue
            queries.append(
                PageQuery(
                    query_id=qid,
                    text=str(row[query_text_col]),
                    relevant_page_id=qrel_map[qid],
                )
            )

    else:
        # ── Legacy format ─────────────────────────────────────────────────────
        # Single 'test' split with one row per (query, page) pair.
        ds = load_dataset(dataset_name, split="test", streaming=True)

        page_ids = []
        page_files = []
        queries = []
        seen_pages: dict[str, str] = {}  # image_filename → page_id

        for i, row in enumerate(islice(ds, max_pages)):
            img = row.get("image")
            if not isinstance(img, PILImage.Image):
                continue
            filename: str = row.get("image_filename") or f"page_{i}"
            query_text: str = row.get("query", "")
            if not query_text:
                continue

            # Each unique page is uploaded exactly once; multiple queries may
            # share the same page (same image_filename).
            if filename not in seen_pages:
                page_id = f"{subset}_{len(page_ids):04d}"
                with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f:
                    tmp_path = f.name
                    img.convert("RGB").save(tmp_path, format="JPEG")
                del img  # free PIL memory before upload
                page_file = await File.from_local(tmp_path)
                os.unlink(tmp_path)
                seen_pages[filename] = page_id
                page_ids.append(page_id)
                page_files.append(page_file)
            else:
                page_id = seen_pages[filename]

            queries.append(
                PageQuery(
                    query_id=f"q{i:04d}",
                    text=query_text,
                    relevant_page_id=page_id,
                )
            )

    print(f"Loaded {len(page_ids)} unique pages, {len(queries)} queries", flush=True)
    return PageDataset(page_ids=page_ids, page_files=page_files, queries=queries)

# ─────────────────────────────────────────────────────────────────────────────
# Tasks — indexing
# ─────────────────────────────────────────────────────────────────────────────

@colpali_indexer.task(cache="auto", retries=2)
async def index_colpali(page_ids: list[str], page_files: list[File]) -> File:
    """
    Encode every page with ColPali-v1.2 and save the multi-vector index.

    ColPali skips OCR entirely. It feeds the raw page image into PaliGemma
    (a vision-language model) and produces one embedding vector per image
    patch — roughly 1,024 patches per page, each of dimension 128.

    _colpali_model() is an lru_cache'd loader. On a cold container, it
    downloads and loads the model once. On a warm container (kept alive by
    ReusePolicy), it returns the already-loaded model instantly from cache —
    no repeated ~7 GB download.

    The index is stored as a .npz file in Flyte's blob store:
      embeddings — float32, shape (n_pages, n_patches, dim)
      page_ids   — matching page ID strings

    cache="auto" + retries=2: the result is stored permanently on success;
    transient failures (e.g. HuggingFace rate limits) are retried twice.
    """
    import torch

    model, processor, device = _colpali_model()

    loop = asyncio.get_running_loop()
    batches = list(_batches(page_files, 4))
    n_batches = len(batches)

    # Submit the first batch to the thread pool before entering the loop so
    # that downloads are already in flight when we first await them.
    prefetch = [loop.run_in_executor(None, _load_image_sync, f) for f in batches[0]]

    all_embeddings: list[np.ndarray] = []
    for batch_idx in range(n_batches):
        images = list(await asyncio.gather(*prefetch))

        # Submit next batch downloads immediately — OS threads run these in
        # parallel with the GPU forward pass below.
        if batch_idx + 1 < n_batches:
            prefetch = [loop.run_in_executor(None, _load_image_sync, f) for f in batches[batch_idx + 1]]

        inputs = processor.process_images(images)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            emb = model(**inputs)  # (batch, n_patches, dim)

        all_embeddings.append(emb.cpu().float().numpy())
        print(f"ColPali: indexed batch {batch_idx + 1}/{n_batches}", flush=True)

    embeddings = np.concatenate(all_embeddings, axis=0)  # (n_pages, n_patches, dim)
    out_path = os.path.join(tempfile.gettempdir(), "colpali_index.npz")
    np.savez(out_path, embeddings=embeddings, page_ids=np.array(page_ids))
    return await File.from_local(out_path)

@siglip_indexer.task(cache="auto", retries=2)
async def index_siglip(page_ids: list[str], page_files: list[File]) -> File:
    """
    Encode every page with SigLIP SO400M and save the single-vector index.

    SigLIP (2023) is Google's CLIP-style image-text model, trained with a
    pairwise sigmoid loss instead of CLIP's softmax contrastive loss. The
    sigmoid loss needs no normalisation over the whole batch, which makes
    large-batch training cheaper. Produces one global embedding per page.

    _siglip_model() caches the model across warm container reuses.

    The index is stored as a .npz file:
      embeddings — float32, shape (n_pages, dim), L2-normalised
      page_ids   — matching page ID strings
    """
    import torch

    model, processor, device = _siglip_model()

    loop = asyncio.get_running_loop()
    batches = list(_batches(page_files, 8))
    n_batches = len(batches)

    # Submit the first batch to the thread pool before entering the loop so
    # that downloads are already in flight when we first await them.
    prefetch = [loop.run_in_executor(None, _load_image_sync, f) for f in batches[0]]

    all_embeddings: list[np.ndarray] = []
    for batch_idx in range(n_batches):
        images = list(await asyncio.gather(*prefetch))

        # Submit next batch downloads immediately — OS threads run these in
        # parallel with the GPU forward pass below.
        if batch_idx + 1 < n_batches:
            prefetch = [loop.run_in_executor(None, _load_image_sync, f) for f in batches[batch_idx + 1]]

        inputs = processor(images=images, return_tensors="pt", padding=True).to(device)

        with torch.no_grad():
            outputs = model.vision_model(**inputs)
            emb = outputs.pooler_output  # (batch, dim)
            emb = emb / emb.norm(dim=-1, keepdim=True)  # L2 normalise

        all_embeddings.append(emb.cpu().float().numpy())
        print(f"SigLIP: indexed batch {batch_idx + 1}/{n_batches}", flush=True)

    embeddings = np.concatenate(all_embeddings, axis=0)  # (n_pages, dim)
    out_path = os.path.join(tempfile.gettempdir(), "siglip_index.npz")
    np.savez(out_path, embeddings=embeddings, page_ids=np.array(page_ids))
    return await File.from_local(out_path)

@ocr_engine.task(cache="auto")
async def extract_page_texts(page_files: list[File]) -> list[str]:
    """
    OCR every page with doctr on GPU to produce a text-only baseline.

    doctr bundles DBNet (detection) + CRNN/SAR (recognition) into a single
    callable predictor. Pages are processed in batches of ocr_batch_size:
    each batch is downloaded in parallel and then passed to the predictor.
    asyncio.to_thread keeps the event loop unblocked during GPU inference.

    Result structure: result.pages[i].blocks[j].lines[k].words[l].value

    Cached: the same corpus is OCR'd at most once across all experiments
    that use the OCR+BM25 backend.
    """
    import gc

    predictor = _ocr_model()

    # Process in batches: download each batch just-in-time so only
    # ocr_batch_size images are in memory at once instead of the whole
    # corpus (2 000 pages in the entry point's run).
    ocr_batch_size = 8
    total = len(page_files)
    texts: list[str] = []
    for start in range(0, total, ocr_batch_size):
        batch_files = page_files[start : start + ocr_batch_size]
        batch_images = list(
            await asyncio.gather(*[asyncio.to_thread(_load_image_sync, f) for f in batch_files])
        )
        batch_np = [np.array(img) for img in batch_images]
        del batch_images
        result = await asyncio.to_thread(predictor, batch_np)
        del batch_np
        for page_output in result.pages:
            texts.append(
                "\n".join(
                    " ".join(word.value for word in line.words)
                    for block in page_output.blocks
                    for line in block.lines
                )
            )
        del result
        gc.collect()
        print(f"OCR: processed {min(start + ocr_batch_size, total)}/{total} pages", flush=True)

    return texts

# ─────────────────────────────────────────────────────────────────────────────
# Tasks — search
# ─────────────────────────────────────────────────────────────────────────────

# {{docs-fragment search_colpali}}
@colpali_indexer.task
async def search_colpali(
    index_file: File,
    queries: list[PageQuery],
    top_k: int,
) -> list[RetrievalResult]:
    """
    Retrieve pages using ColPali MaxSim late interaction via DynamicBatcher.

    MaxSim score for page p given query q:
        score(q, p) = Σ_{t ∈ query tokens} max_{j ∈ page patches} (q_t · p_j)

    Each query is submitted to the process-level DynamicBatcher, which
    aggregates queries from all concurrent search_colpali invocations on the
    same warm container (concurrency=8) into a single GPU batch. This keeps
    the GPU saturated rather than running one small batch per caller.

    The batcher's process_fn runs GPU work in asyncio.to_thread, so the
    aggregation loop stays live while the GPU encodes and scores.
    """
    batcher = await _get_colpali_search_batcher(index_file)
    futures = await batcher.submit_batch(queries)
    all_ranked: list[list[str]] = list(await asyncio.gather(*futures))

    return [
        RetrievalResult(query_id=q.query_id, ranked_page_ids=ranked[:top_k])
        for q, ranked in zip(queries, all_ranked)
    ]
# {{/docs-fragment search_colpali}}

@siglip_indexer.task
async def search_siglip(
    index_file: File,
    queries: list[PageQuery],
    top_k: int,
) -> list[RetrievalResult]:
    """
    Retrieve pages using SigLIP cosine similarity via DynamicBatcher.

    Each query is submitted to the process-level DynamicBatcher, which
    aggregates queries from all concurrent search_siglip invocations on the
    same warm container (concurrency=3) into a single GPU batch.

    SigLIP's single-vector embeddings make full vectorisation safe —
    the scores matrix (n_pages x n_queries) is small enough to materialise
    in one GPU call regardless of batch size.
    """
    batcher = await _get_siglip_search_batcher(index_file)
    futures = await batcher.submit_batch(queries)
    all_ranked: list[list[str]] = list(await asyncio.gather(*futures))

    return [
        RetrievalResult(query_id=q.query_id, ranked_page_ids=ranked[:top_k])
        for q, ranked in zip(queries, all_ranked)
    ]

@driver.task
async def search_bm25(
    page_texts: list[str],
    page_ids: list[str],
    queries: list[PageQuery],
    top_k: int,
) -> list[RetrievalResult]:
    """
    Retrieve pages using BM25 over OCR'd text.

    The standard keyword-based baseline. BM25 scoring needs no GPU; strong
    on text-dense pages, weak on visual content that OCR cannot capture.
    """
    tokenized = [text.lower().split() for text in page_texts]
    bm25 = BM25Okapi(tokenized)

    results: list[RetrievalResult] = []
    for q in queries:
        scores = bm25.get_scores(q.text.lower().split())
        ranked = sorted(range(len(page_ids)), key=lambda i: -scores[i])[:top_k]
        results.append(
            RetrievalResult(
                query_id=q.query_id,
                ranked_page_ids=[page_ids[i] for i in ranked],
            )
        )
    return results

# ─────────────────────────────────────────────────────────────────────────────
# Tasks — evaluation
# ─────────────────────────────────────────────────────────────────────────────

@driver.task
async def evaluate(
    results: list[RetrievalResult],
    ground_truth: list[PageQuery],
    k: int,
) -> Metrics:
    """
    Compute Recall@K, NDCG@K, and MRR for a single retrieval model.

    Recall@K  — was the correct page in the top-K results?
    NDCG@K    — normalised discounted cumulative gain; rewards earlier hits.
    MRR       — mean reciprocal rank of the first correct result.

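    All three are averaged over all queries; higher is better. A query whose
    correct page is not retrieved scores 0 on each. Because the search tasks
    return only top_k pages, MRR here is effectively MRR@K.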
    """
    gt_map = {q.query_id: q.relevant_page_id for q in ground_truth}
    recall_vals, ndcg_vals, mrr_vals = [], [], []

    for r in results:
        relevant = gt_map.get(r.query_id, "")
        top = r.ranked_page_ids[:k]

        recall_vals.append(1.0 if relevant in top else 0.0)

        rels = [1 if pid == relevant else 0 for pid in top]
        idcg = _dcg([1])  # ideal: correct page at rank 1
        ndcg_vals.append(_dcg(rels) / idcg if idcg > 0 else 0.0)

        rr = 0.0
        for rank, pid in enumerate(r.ranked_page_ids, start=1):
            if pid == relevant:
                rr = 1.0 / rank
                break
        mrr_vals.append(rr)

    return Metrics(
        recall_at_k=float(np.mean(recall_vals)),
        ndcg_at_k=float(np.mean(ndcg_vals)),
        mrr=float(np.mean(mrr_vals)),
        k=k,
    )

# ─────────────────────────────────────────────────────────────────────────────
# Tasks — report
# ─────────────────────────────────────────────────────────────────────────────

@driver.task(report=True)
async def generate_report(report: ComparisonReport) -> None:
    """
    Emit an interactive HTML report visible in the Flyte UI.

    report=True marks this task as a reporting task. Flyte renders the HTML
    passed to flyte.report.replace.aio() directly in the execution detail
    page, so no separate dashboard or export step is required.

    The report contains:
      - Summary cards: experiment count, best model, best Recall@K.
      - Grouped bar chart: Recall@K, NDCG@K, MRR side-by-side per experiment.
      - Ranked results table with all three metrics.
    """
    sorted_results = sorted(report.results, key=lambda r: -r.metrics.recall_at_k)
    best = sorted_results[0]

    labels = [r.config.name for r in sorted_results]
    recall_vals = [r.metrics.recall_at_k for r in sorted_results]
    ndcg_vals = [r.metrics.ndcg_at_k for r in sorted_results]
    mrr_vals = [r.metrics.mrr for r in sorted_results]

    table_rows = "".join(
        f"""
        <tr>
          <td>{r.config.name}</td>
          <td>{r.config.model.value}</td>
          <td>{r.metrics.recall_at_k:.3f}</td>
          <td>{r.metrics.ndcg_at_k:.3f}</td>
          <td>{r.metrics.mrr:.3f}</td>
          <td>{r.metrics.k}</td>
        </tr>"""
        for r in sorted_results
    )

    html = f"""<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <title>Visual Document Retrieval — Results</title>
  <script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
  <style>
    * {{ box-sizing: border-box; margin: 0; padding: 0; }}
    body {{
      font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;
      background: #f0f2f5; color: #222; padding: 24px;
    }}
    h1 {{ font-size: 1.6em; margin-bottom: 4px; }}
    .subtitle {{ color: #666; margin-bottom: 24px; font-size: 0.95em; }}
    .cards {{
      display: flex; gap: 16px; flex-wrap: wrap; margin-bottom: 28px;
    }}
    .card {{
      background: #fff; border-radius: 10px; padding: 18px 24px;
      box-shadow: 0 1px 4px rgba(0,0,0,.08); min-width: 160px;
    }}
    .card-value {{ font-size: 1.9em; font-weight: 700; color: #4f46e5; }}
    .card-label {{ font-size: 0.8em; color: #888; text-transform: uppercase;
                   letter-spacing: .04em; margin-top: 2px; }}
    .chart-box {{
      background: #fff; border-radius: 10px; padding: 24px;
      box-shadow: 0 1px 4px rgba(0,0,0,.08); margin-bottom: 28px;
    }}
    .chart-box h2 {{ font-size: 1em; margin-bottom: 16px; color: #444; }}
    table {{ width: 100%; border-collapse: collapse; font-size: 0.9em; }}
    th {{
      background: #4f46e5; color: #fff; padding: 10px 14px;
      text-align: left; font-weight: 600;
    }}
    td {{ padding: 9px 14px; border-bottom: 1px solid #eee; }}
    tr:hover td {{ background: #f8f8ff; }}
    tr:first-child td {{ font-weight: 600; }}
  </style>
</head>
<body>
  <h1>Visual Document Retrieval — Experiment Comparison</h1>
  <p class="subtitle">ViDoRe benchmark &middot; {len(report.results)} experiment(s)</p>

  <div class="cards">
    <div class="card">
      <div class="card-value">{len(report.results)}</div>
      <div class="card-label">Experiments</div>
    </div>
    <div class="card">
      <div class="card-value">{best.config.name}</div>
      <div class="card-label">Best by Recall@K</div>
    </div>
    <div class="card">
      <div class="card-value">{best.metrics.recall_at_k:.3f}</div>
      <div class="card-label">Best Recall@{best.metrics.k}</div>
    </div>
    <div class="card">
      <div class="card-value">{best.metrics.ndcg_at_k:.3f}</div>
      <div class="card-label">Best NDCG@{best.metrics.k}</div>
    </div>
    <div class="card">
      <div class="card-value">{best.metrics.mrr:.3f}</div>
      <div class="card-label">Best MRR</div>
    </div>
  </div>

  <div class="chart-box">
    <h2>Metrics by Experiment</h2>
    <canvas id="metricsChart" height="100"></canvas>
  </div>

  <div class="chart-box">
    <h2>Ranked Results</h2>
    <table>
      <thead>
        <tr>
          <th>Experiment</th><th>Model</th>
          <th>Recall@K</th><th>NDCG@K</th><th>MRR</th><th>K</th>
        </tr>
      </thead>
      <tbody>{table_rows}</tbody>
    </table>
  </div>

  <script>
    new Chart(document.getElementById('metricsChart'), {{
      type: 'bar',
      data: {{
        labels: {json.dumps(labels)},
        datasets: [
          {{
            label: 'Recall@K',
            data: {json.dumps(recall_vals)},
            backgroundColor: 'rgba(79,70,229,0.75)',
            borderRadius: 4
          }},
          {{
            label: 'NDCG@K',
            data: {json.dumps(ndcg_vals)},
            backgroundColor: 'rgba(16,185,129,0.75)',
            borderRadius: 4
          }},
          {{
            label: 'MRR',
            data: {json.dumps(mrr_vals)},
            backgroundColor: 'rgba(245,158,11,0.75)',
            borderRadius: 4
          }}
        ]
      }},
      options: {{
        responsive: true,
        plugins: {{ legend: {{ position: 'top' }} }},
        scales: {{
          y: {{ beginAtZero: true, max: 1.0,
               title: {{ display: true, text: 'Score' }} }}
        }}
      }}
    }});
  </script>
</body>
</html>"""

    await flyte.report.replace.aio(html)
    await flyte.report.flush.aio()

# ─────────────────────────────────────────────────────────────────────────────
# Experiment orchestration
# ─────────────────────────────────────────────────────────────────────────────

# {{docs-fragment run_experiment}}
@driver.task
async def run_experiment(config: ExperimentConfig, dataset: PageDataset) -> ExperimentResult:
    """
    End-to-end retrieval pipeline for a single ExperimentConfig.

    Flyte v2's dynamic execution means this driver task can call GPU tasks
    (index_colpali/search_colpali, index_siglip/search_siglip) based on the runtime value of config.model
    — no static DAG wiring required. The if/elif is plain Python; Flyte
    schedules the selected sub-tasks on the appropriate environment.

    Caching: two experiments that share the same model and corpus (e.g. ColPali
    at top_k=5 and top_k=10) will hit the same cached index. GPU work is paid
    at most once per (model, corpus) pair across all experiments.

    Search queries are sharded into chunks of SEARCH_SHARD_SIZE and dispatched
    as concurrent task invocations. All shards land on the single warm container
    (replicas=1) and feed the same DynamicBatcher simultaneously, keeping the
    GPU saturated throughout search rather than processing one large sequential
    batch from a single caller.

    flyte.group wraps each experiment in a named span in the Flyte UI, making
    it easy to compare latencies and drill into individual runs.
    """
    SEARCH_SHARD_SIZE = 256

    with flyte.group(config.name):
        if config.model == RetrievalModel.COLPALI:
            index_file = await index_colpali(dataset.page_ids, dataset.page_files)
            shards = list(_batches(dataset.queries, SEARCH_SHARD_SIZE))
            shard_results = await asyncio.gather(
                *[search_colpali(index_file, shard, config.top_k) for shard in shards]
            )
            results = [r for shard in shard_results for r in shard]

        elif config.model == RetrievalModel.SIGLIP:
            index_file = await index_siglip(dataset.page_ids, dataset.page_files)
            shards = list(_batches(dataset.queries, SEARCH_SHARD_SIZE))
            shard_results = await asyncio.gather(
                *[search_siglip(index_file, shard, config.top_k) for shard in shards]
            )
            results = [r for shard in shard_results for r in shard]

        else:  # RetrievalModel.OCR_BM25
            page_texts = await extract_page_texts(dataset.page_files)
            results = await search_bm25(page_texts, dataset.page_ids, dataset.queries, config.top_k)

        metrics = await evaluate(results, dataset.queries, config.top_k)

    return ExperimentResult(config=config, metrics=metrics)
# {{/docs-fragment run_experiment}}

# {{docs-fragment compare_experiments}}
@driver.task
async def compare_experiments(
    configs: list[ExperimentConfig],
    subset: str = "docvqa",
    max_pages: int = 200,
) -> ComparisonReport:
    """
    Fan out over all experiment configs and return a ranked comparison table.

    The dataset is loaded once and shared across all experiments. Each config
    runs as a concurrent Flyte task via asyncio.gather. Experiments that share
    a model reuse the cached index — you only pay GPU time for new work.

    On completion, generate_report emits an interactive Chart.js HTML report
    visible directly in the Flyte execution detail page.

    The entry point below runs vidore_v3_finance_en (~2 942 corpus pages,
    1 854 queries) with max_pages=2000 to exercise the GPU pipeline at scale.
    The function's own defaults (docvqa, 200 pages) make a quicker smoke test.
    """
    dataset = await load_vidore_pages(subset=subset, max_pages=max_pages)

    # All experiments launch concurrently. Shared cached outputs (same model,
    # same corpus) are served from cache rather than recomputed.
    experiment_coros = [run_experiment(config=cfg, dataset=dataset) for cfg in configs]
    results: list[ExperimentResult] = list(await asyncio.gather(*experiment_coros))

    report = ComparisonReport(results=results)
    print(report.summary())
    best = report.best_by("recall_at_k")
    print(f"\nBest by Recall@{best.metrics.k}: {best.config.name}")

    # Emit the interactive HTML report in the Flyte UI.
    await generate_report(report)

    return report
# {{/docs-fragment compare_experiments}}

# ─────────────────────────────────────────────────────────────────────────────
# Entry point
# ─────────────────────────────────────────────────────────────────────────────

if __name__ == "__main__":
    flyte.init_from_config()

    # Define the experiment grid. Each ExperimentConfig is one point in the
    # design space. Varying top_k or adding another run of an existing model
    # is one line here; a new model also needs its own index/search tasks
    # and a branch in run_experiment.
    #
    # ColPali appears twice with different top_k values. The cache ensures
    # index_colpali runs only once and both experiments share that result.
    # {{docs-fragment grid}}
    configs = [
        ExperimentConfig(name="colpali-top5", model=RetrievalModel.COLPALI, top_k=5),
        ExperimentConfig(name="colpali-top10", model=RetrievalModel.COLPALI, top_k=10),
        ExperimentConfig(name="siglip-top5", model=RetrievalModel.SIGLIP, top_k=5),
        ExperimentConfig(name="ocr-bm25-top5", model=RetrievalModel.OCR_BM25, top_k=5),
    ]
    # {{/docs-fragment grid}}

    run = flyte.with_runcontext().run(
        compare_experiments,
        configs=configs,
        subset="vidore_v3_finance_en",
        max_pages=2000,
    )
    print(f"Run URL: {run.url}")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/multimodal-retrieval-evaluation/retrieval_eval.py*
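
The `search_colpali` docstring defines MaxSim late interaction, and `evaluate` scores each query against a single relevant page. The sketch below reproduces both in plain NumPy on random toy embeddings, so the arithmetic can be checked without a GPU. `maxsim_scores` and `metrics_for_query` are hypothetical helpers. They are not part of the tutorial or of `colpali-engine`. The array shapes are assumptions that mirror the index layout described in the docstrings.

```python
import math

import numpy as np


def maxsim_scores(query_emb: np.ndarray, page_embs: np.ndarray) -> np.ndarray:
    """MaxSim score of one query against every page (hypothetical helper).

    query_emb: (n_query_tokens, dim); page_embs: (n_pages, n_patches, dim).
    Returns shape (n_pages,) with score[p] = sum_t max_j (q_t . p_j).
    """
    # sim[p, t, j] is the dot product of query token t with patch j of page p.
    sim = np.einsum("td,pjd->ptj", query_emb, page_embs)
    # Best-matching patch for each query token, summed over query tokens.
    return sim.max(axis=2).sum(axis=1)


def metrics_for_query(ranked: list[str], relevant: str, k: int) -> tuple[float, float, float]:
    """Recall@K, NDCG@K and reciprocal rank when exactly one page is relevant."""
    top = ranked[:k]
    recall = 1.0 if relevant in top else 0.0
    # With one relevant page the ideal DCG is 1 / log2(2) = 1, so NDCG@K = DCG@K.
    ndcg = sum((1.0 / math.log2(rank + 2) for rank, pid in enumerate(top) if pid == relevant), 0.0)
    # Reciprocal rank over the full ranked list (0.0 if the page never appears).
    rr = next((1.0 / rank for rank, pid in enumerate(ranked, start=1) if pid == relevant), 0.0)
    return recall, ndcg, rr


# Toy data: 4 pages x 6 patches and a 3-token query, all 8-dimensional.
rng = np.random.default_rng(0)
pages = rng.normal(size=(4, 6, 8)).astype(np.float32)
query = rng.normal(size=(3, 8)).astype(np.float32)

scores = maxsim_scores(query, pages)
page_ids = ["p0", "p1", "p2", "p3"]
ranked = [page_ids[i] for i in np.argsort(-scores)]
print(ranked, metrics_for_query(ranked, relevant="p2", k=2))
```

This toy `einsum` materialises the full `(n_pages, n_tokens, n_patches)` similarity tensor. That is fine at this size. At corpus scale, chunking pages or queries keeps memory bounded.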

> **📝 Note**
>
> The `DynamicBatcher` implementation lives in the `extras/` package next to the example. Run the script from the example directory so the import resolves.
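
`DynamicBatcher` itself is not shown on this page. The idea behind it is that many concurrent callers each submit work, and one background loop fuses whatever has arrived into a single `process_fn` call. The sketch below is a minimal, hypothetical asyncio micro-batcher that illustrates this idea. `MiniBatcher` and its `submit` method are invented for the sketch and do not mirror the `extras` API. In particular, it has no cost-based batching (`target_batch_cost`) and no prefetching (`prefetch_batches`).

```python
import asyncio
from typing import Any, Awaitable, Callable


class MiniBatcher:
    """Toy micro-batcher (hypothetical): fuses concurrent submits into one call.

    process_fn must return exactly one result per input item, in order.
    """

    def __init__(
        self,
        process_fn: Callable[[list[Any]], Awaitable[list[Any]]],
        max_batch_size: int = 128,
        batch_timeout_s: float = 0.05,
    ) -> None:
        self._process_fn = process_fn
        self._max_batch_size = max_batch_size
        self._batch_timeout_s = batch_timeout_s
        self._queue: asyncio.Queue = asyncio.Queue()
        self._worker: asyncio.Task | None = None  # keep a reference so it is not GC'd

    async def start(self) -> None:
        self._worker = asyncio.create_task(self._run())

    async def submit(self, item: Any) -> Any:
        # Each caller gets its own future, resolved when its batch finishes.
        fut = asyncio.get_running_loop().create_future()
        await self._queue.put((item, fut))
        return await fut

    async def _run(self) -> None:
        while True:
            # Block until at least one request arrives...
            batch = [await self._queue.get()]
            # ...then give other callers a short window to join the same batch.
            await asyncio.sleep(self._batch_timeout_s)
            while len(batch) < self._max_batch_size and not self._queue.empty():
                batch.append(self._queue.get_nowait())
            items = [item for item, _ in batch]
            try:
                results = await self._process_fn(items)
            except Exception as exc:  # propagate the failure to every waiter
                for _, fut in batch:
                    if not fut.done():
                        fut.set_exception(exc)
                continue
            for (_, fut), result in zip(batch, results):
                if not fut.done():
                    fut.set_result(result)


batch_sizes: list[int] = []


async def double_all(items: list[int]) -> list[int]:
    batch_sizes.append(len(items))  # record how many requests were fused
    return [2 * x for x in items]


async def main() -> list[int]:
    batcher = MiniBatcher(double_all, max_batch_size=128, batch_timeout_s=0.05)
    await batcher.start()
    return await asyncio.gather(*(batcher.submit(i) for i in range(10)))


results = asyncio.run(main())
print(results, batch_sizes)
```

All ten concurrent `submit` calls arrive within the timeout window, so `double_all` runs once on a batch of ten. This is the effect the tutorial relies on when several shard tasks share one warm container.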

## Run one experiment

`run_experiment` selects the right index/search path based on the runtime value of `config.model` — Flyte v2's dynamic execution means there's no static DAG to wire up. `flyte.group` wraps each experiment in a named span in the UI.
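
Because the branch on `config.model` is ordinary Python, the shard-and-gather pattern can be tried locally without Flyte. In this minimal sketch, `Model`, `fake_gpu_search`, and `fake_bm25_search` are hypothetical stand-ins for `RetrievalModel` and the real search tasks. `asyncio.gather` returns results in submission order, so flattening the shard results keeps them aligned with the input queries.

```python
import asyncio
import enum


class Model(enum.Enum):  # hypothetical stand-in for RetrievalModel
    GPU = "gpu"
    BM25 = "bm25"


shard_sizes: list[int] = []  # records how the queries were split


def batches(items: list, size: int):
    """Yield successive fixed-size slices, like _batches in the tutorial."""
    for start in range(0, len(items), size):
        yield items[start : start + size]


async def fake_gpu_search(shard: list[str], top_k: int) -> list[str]:
    shard_sizes.append(len(shard))
    await asyncio.sleep(0)  # yield control, standing in for a remote task call
    return [f"gpu:{q}" for q in shard]  # top_k is unused in this stand-in


async def fake_bm25_search(queries: list[str], top_k: int) -> list[str]:
    return [f"bm25:{q}" for q in queries]


async def run_experiment_sketch(
    model: Model, queries: list[str], top_k: int, shard_size: int = 256
) -> list[str]:
    if model is Model.GPU:
        # One concurrent call per shard, then flatten back into a single list.
        shard_results = await asyncio.gather(
            *[fake_gpu_search(shard, top_k) for shard in batches(queries, shard_size)]
        )
        return [r for shard in shard_results for r in shard]
    # Text baseline: a single call over all queries, no sharding.
    return await fake_bm25_search(queries, top_k)


queries = [f"q{i}" for i in range(600)]
gpu_results = asyncio.run(run_experiment_sketch(Model.GPU, queries, top_k=5))
bm25_results = asyncio.run(run_experiment_sketch(Model.BM25, queries, top_k=5))
print(shard_sizes, gpu_results[:2], bm25_results[:2])
```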

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "colpali-engine>=0.3.1",
#     "transformers>=4.41",
#     "sentencepiece>=0.2",
#     "torch>=2.0",
#     "pillow>=10",
#     "datasets>=2.18",
#     "rank-bm25>=0.2",
#     "numpy>=1.26",
#     "python-doctr[torch]>=0.8",
#     "pydantic>=2.0",
#     "flyte>=2.0.0",
# ]
# ///
"""
Multimodal Retrieval Evaluation Pipeline

This tutorial is an experiment framework for benchmarking visual document
retrieval approaches on the ViDoRe benchmark. Each experiment is defined by
an ExperimentConfig; the pipeline fans them out as concurrent Flyte tasks and
returns a ranked comparison table with an interactive HTML report.

The corpus is a set of PDF page images; queries are plain-text questions. Each
retrieval method must find the page that answers each question — no text is
provided to the model, only the raw image.

  ColPali-v1.2  — patch-level multi-vector embeddings from a VLM (PaliGemma).
                  No OCR. The model produces one vector per image patch
                  (~1024 per page). MaxSim late-interaction scoring finds the
                  best matching patch for each query token.

  SigLIP-SO400M — single global embedding per page from Google's 2023 CLIP
                  successor. One matrix multiply per query; fast and effective
                  but a single vector cannot localise fine-grained regions.

  OCR + BM25    — text-only baseline. doctr (GPU OCR) extracts text in
                  batches, BM25 matches keywords. Strong on text-dense pages;
                  fails on charts, tables, and figures where content is visual.

"""

import asyncio
import enum
import json
import math
import os
import tempfile
from functools import lru_cache
from io import BytesIO
from itertools import islice

import numpy as np
from PIL import Image as PILImage
from pydantic import BaseModel
from rank_bm25 import BM25Okapi

from extras import DynamicBatcher

import flyte
import flyte.report
from flyte.io import File

# ─────────────────────────────────────────────────────────────────────────────
# Environments
# ─────────────────────────────────────────────────────────────────────────────

# One Docker image for all tasks. The PEP 723 header defines Python deps.
# ca-certificates is required for HTTPS calls to HuggingFace and blob stores;
# libxcb1, libgl1 and libglib2.0-0 are system libraries needed by OpenCV,
# which doctr pulls in.
# {{docs-fragment image}}
image = (
    flyte.Image.from_uv_script(__file__, name="vidore-eval-v2")
    .with_apt_packages("ca-certificates", "libxcb1", "libgl1", "libglib2.0-0")
    # unionai-reuse installs the unionai-actor-bridge binary required by ReusePolicy.
    # Without it every reusable container exits with StartError (exit code 128).
    .with_pip_packages("unionai-reuse>=0.1.11")
)
# {{/docs-fragment image}}

# GPU environment for ColPali image encoding and search.
#
# ReusePolicy keeps a single warm GPU container (replicas=1) alive between task calls.
# Without it, every task invocation cold-starts a new container and downloads
# ColPali-v1.2 (~7 GB) from scratch. With it, the container — and the model
# weights already loaded into VRAM — is reused for the next task dispatch.
#
#   replicas=1      single warm container — all concurrent shard calls land
#                   here so they share one DynamicBatcher process
#   concurrency=8   up to 8 query-shard tasks run simultaneously on the
#                   container, all feeding the same DynamicBatcher queue
#   idle_ttl=120    keep alive 2 min after the last task finishes
#   scaledown_ttl=60 scale to zero after 1 min of complete inactivity
# {{docs-fragment envs}}
colpali_indexer = flyte.TaskEnvironment(
    name="vidore-colpali-indexer",
    image=image,
    resources=flyte.Resources(cpu=4, memory="16Gi", gpu="A10G:1"),
    reusable=flyte.ReusePolicy(
        replicas=1,
        concurrency=8,
        idle_ttl=120,
        scaledown_ttl=60,
    ),
)

# GPU environment for SigLIP image encoding and search.
#
# Separate from the ColPali environment so each model's warm containers
# are managed independently — ColPali and SigLIP experiments can scale
# without contending for the same pool of reusable containers.
siglip_indexer = flyte.TaskEnvironment(
    name="vidore-siglip-indexer",
    image=image,
    resources=flyte.Resources(cpu=4, memory="8Gi", gpu=1),
    reusable=flyte.ReusePolicy(
        replicas=1,
        concurrency=8,
        idle_ttl=120,
        scaledown_ttl=60,
    ),
)

# GPU environment for doctr OCR. doctr runs DBNet (detection) + CRNN (recognition)
# in batches on GPU — much faster than CPU Tesseract.
# No ReusePolicy needed: the result is cached, so this task runs at most once per corpus.
ocr_engine = flyte.TaskEnvironment(
    name="vidore-ocr-engine",
    image=image,
    resources=flyte.Resources(cpu=4, memory="20Gi", gpu=1),
)

# Driver: orchestration, BM25 search, evaluation, and reporting.
# depends_on ensures the shared Docker image is built before all environments
# try to schedule tasks.
driver = flyte.TaskEnvironment(
    name="vidore-driver",
    image=image,
    resources=flyte.Resources(cpu=2, memory="12Gi"),
    depends_on=[colpali_indexer, siglip_indexer, ocr_engine],
)
# {{/docs-fragment envs}}

# ─────────────────────────────────────────────────────────────────────────────
# Configuration types
# ─────────────────────────────────────────────────────────────────────────────

# {{docs-fragment config_types}}
class RetrievalModel(str, enum.Enum):
    """Retrieval backend to evaluate."""

    COLPALI = "colpali-v1.2"  # multi-vector patch embeddings, MaxSim
    SIGLIP = "siglip-so400m"  # single-vector global embedding, cosine sim
    OCR_BM25 = "ocr+bm25"  # text extracted by doctr OCR, ranked by BM25

class ExperimentConfig(BaseModel):
    """
    All knobs for one retrieval experiment. Passed as a typed Flyte input.

    Because ExperimentConfig is a Pydantic model, Flyte serialises it
    alongside every task output — so you can always reconstruct which
    config produced which metric without maintaining a separate log.
    """

    name: str  # human-readable label shown in the comparison table
    model: RetrievalModel
    top_k: int = 5  # number of pages to retrieve per query
# {{/docs-fragment config_types}}

# ─────────────────────────────────────────────────────────────────────────────
# Data types
# ─────────────────────────────────────────────────────────────────────────────

# {{docs-fragment data_types}}
class PageQuery(BaseModel):
    """One retrieval query with its ground-truth page."""

    query_id: str
    text: str  # e.g. "What was revenue growth in Q3?"
    relevant_page_id: str  # one correct page per query

class PageDataset(BaseModel):
    """
    A corpus of document page images paired with text queries.

    page_ids:   unique page identifiers (derived from ViDoRe image filenames).
    page_files: the same pages stored in Flyte's blob store as JPEG File
                handles. Tasks read images directly from here; no live HTTP.
    queries:    text questions with ground-truth page IDs for evaluation.
    """

    page_ids: list[str]
    page_files: list[File]
    queries: list[PageQuery]

    # Pydantic v2 style; equivalent to the deprecated inner `class Config`.
    model_config = {"arbitrary_types_allowed": True}

class RetrievalResult(BaseModel):
    query_id: str
    ranked_page_ids: list[str]  # ordered best → worst

class Metrics(BaseModel):
    recall_at_k: float
    ndcg_at_k: float
    mrr: float
    k: int

class ExperimentResult(BaseModel):
    config: ExperimentConfig
    metrics: Metrics
# {{/docs-fragment data_types}}

class ComparisonReport(BaseModel):
    results: list[ExperimentResult]

    def best_by(self, metric: str = "recall_at_k") -> ExperimentResult:
        return max(self.results, key=lambda r: getattr(r.metrics, metric))

    def summary(self) -> str:
        header = f"{'Experiment':<30} {'Model':<18} {'Recall@K':>10} {'NDCG@K':>8} {'MRR':>7}"
        sep = "─" * len(header)
        rows = [header, sep]
        for r in sorted(self.results, key=lambda x: -x.metrics.recall_at_k):
            rows.append(
                f"{r.config.name:<30} "
                f"{r.config.model.value:<18} "
                f"{r.metrics.recall_at_k:>10.3f} "
                f"{r.metrics.ndcg_at_k:>8.3f} "
                f"{r.metrics.mrr:>7.3f}"
            )
        return "\n".join(rows)

# ─────────────────────────────────────────────────────────────────────────────
# Cached model loaders
# ─────────────────────────────────────────────────────────────────────────────
# These functions are at module level so they are shared across all tasks that
# run on the same warm container (via ReusePolicy). lru_cache(maxsize=1) means
# the model is loaded from disk/HuggingFace exactly once per container process
# and kept in GPU memory for every subsequent task dispatch to that container.

@lru_cache(maxsize=1)
def _colpali_model():
    """Load ColPali-v1.2 into GPU memory and cache the result.

    device_map= is the correct loading pattern for ColPali's PaliGemma
    backbone; it handles weight placement via accelerate. torch.compile is
    skipped — ColPali is GPU-compute-bound and the DynamicBatcher's cross-
    invocation batching is the primary GPU utilisation mechanism.
    """
    import torch
    from colpali_engine.models import ColPali, ColPaliProcessor

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = ColPali.from_pretrained(
        "vidore/colpali-v1.2",
        torch_dtype=torch.bfloat16,
        device_map=device,
    )
    processor = ColPaliProcessor.from_pretrained("vidore/colpali-v1.2")
    return model, processor, device

@lru_cache(maxsize=1)
def _siglip_model():
    """Load SigLIP SO400M into GPU memory and cache the result.

    The model is not wrapped in torch.compile. The tasks call the
    model.vision_model and model.text_model sub-modules directly, and
    attribute access on a compiled wrapper returns the original, uncompiled
    sub-module, so compiling the top-level model would have no effect. As
    with ColPali, cross-invocation batching is the main GPU-utilisation lever.
    """
    import torch
    from transformers import AutoModel, AutoProcessor

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModel.from_pretrained("google/siglip-so400m-patch14-224").to(device)
    processor = AutoProcessor.from_pretrained("google/siglip-so400m-patch14-224")
    return model, processor, device

@lru_cache(maxsize=1)
def _ocr_model():
    """Load the doctr OCR predictor onto GPU and cache it.

    doctr's ocr_predictor bundles a detection model (DBNet) and a
    recognition model (CRNN/SAR) into a single callable. pretrained=True
    downloads both model weights from doctr's model zoo on first use.
    """
    import torch
    from doctr.models import ocr_predictor

    predictor = ocr_predictor(pretrained=True)
    if torch.cuda.is_available():
        predictor = predictor.cuda()
    return predictor

# ─────────────────────────────────────────────────────────────────────────────
# Search batcher singletons
# ─────────────────────────────────────────────────────────────────────────────
# One DynamicBatcher per model, shared across all concurrent search task
# invocations on the same warm container (concurrency=8). Queries from every
# concurrent caller are aggregated into a single GPU batch, maximizing
# throughput compared to each invocation running its own forward pass.
#
# Initialised lazily on the first search call via double-checked locking and
# lives for the container's lifetime. The process_fn runs GPU work via
# asyncio.to_thread so the aggregation loop can continue collecting queries
# from other callers while the GPU processes the current batch.
#
# File is not hashable so alru_cache cannot be used here; module-level state
# with asyncio.Lock is the correct pattern.
#
# Assumption: index_colpali/index_siglip use cache="auto", so the same corpus
# always produces the same index File across all callers on this container. If
# the index file ever changed between calls, the batcher would silently continue
# using the corpus embeddings loaded from the first call.

_colpali_batcher: DynamicBatcher | None = None
_colpali_batcher_lock = asyncio.Lock()
_siglip_batcher: DynamicBatcher | None = None
_siglip_batcher_lock = asyncio.Lock()

async def _get_colpali_search_batcher(index_file: File) -> DynamicBatcher:
    """Return the process-level ColPali search batcher, creating it on first call."""
    global _colpali_batcher
    if _colpali_batcher is not None:
        return _colpali_batcher
    async with _colpali_batcher_lock:
        if _colpali_batcher is not None:
            return _colpali_batcher

        import torch

        data = await _load_npz(index_file)
        corpus_emb = torch.from_numpy(data["embeddings"])  # (n_pages, n_patches, dim)
        index_page_ids: list[str] = list(data["page_ids"])
        model, processor, device = _colpali_model()
        corpus_emb = corpus_emb.to(device, dtype=torch.float32)

        async def colpali_process_fn(batch: list[PageQuery]) -> list[list[str]]:
            def _gpu_work() -> list[list[str]]:
                query_inputs = processor.process_queries([q.text for q in batch])
                query_inputs = {k: v.to(device) for k, v in query_inputs.items()}
                with torch.no_grad():
                    query_embs = model(**query_inputs).float()  # (B, T, D)
                    query_chunk = 8
                    n_pages = corpus_emb.shape[0]
                    all_scores = torch.empty(len(batch), n_pages, device=device)
                    for start in range(0, len(batch), query_chunk):
                        chunk = query_embs[start : start + query_chunk]
                        all_scores[start : start + query_chunk] = (
                            torch.einsum("ctd,pjd->ctpj", chunk, corpus_emb)
                            .max(dim=3).values
                            .sum(dim=1)
                        )
                    sorted_indices = all_scores.argsort(dim=1, descending=True).cpu().tolist()
                return [[index_page_ids[j] for j in ranked] for ranked in sorted_indices]

            # Run GPU work in a thread so the event loop — and the batcher's
            # aggregation loop — remain unblocked while the GPU is busy.
            return await asyncio.to_thread(_gpu_work)

        batcher: DynamicBatcher[PageQuery, list[str]] = DynamicBatcher(
            process_fn=colpali_process_fn,
            target_batch_cost=128,
            max_batch_size=128,
            batch_timeout_s=0.05,
            default_cost=1,
            prefetch_batches=2,
        )
        await batcher.start()
        _colpali_batcher = batcher
    return _colpali_batcher
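
# Sketch (hypothetical helper, not used by the pipeline): the double-checked
# locking pattern above, reduced to its essentials. The fast path returns the
# cached object without touching the lock. The second check inside the lock
# stops a caller that was queued on the lock from building a second copy.
# `build` stands in for the expensive setup (loading the index, starting the
# batcher) and is an assumption of this sketch.
class _SketchLazyAsyncSingleton:
    def __init__(self, build):
        import asyncio

        self._build = build  # async zero-argument factory
        self._value = None
        self._lock = asyncio.Lock()

    async def get(self):
        if self._value is not None:  # fast path: already initialised
            return self._value
        async with self._lock:
            if self._value is None:  # re-check: another caller may have won
                self._value = await self._build()
        return self._value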

async def _get_siglip_search_batcher(index_file: File) -> DynamicBatcher:
    """Return the process-level SigLIP search batcher, creating it on first call."""
    global _siglip_batcher
    if _siglip_batcher is not None:
        return _siglip_batcher
    async with _siglip_batcher_lock:
        if _siglip_batcher is not None:
            return _siglip_batcher

        import torch

        data = await _load_npz(index_file)
        corpus_emb = torch.from_numpy(data["embeddings"])  # (n_pages, dim), L2-normalised
        index_page_ids: list[str] = list(data["page_ids"])
        model, processor, device = _siglip_model()
        corpus_emb = corpus_emb.to(device)

        async def siglip_process_fn(batch: list[PageQuery]) -> list[list[str]]:
            def _gpu_work() -> list[list[str]]:
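                # SigLIP was trained on text padded to the tokenizer's max_length;
                # the Hugging Face docs recommend padding="max_length" to match.
                text_inputs = processor(
                    text=[q.text for q in batch],
                    return_tensors="pt",
                    padding="max_length",
                    truncation=True,
                ).to(device)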
                with torch.no_grad():
                    text_out = model.text_model(**text_inputs)
                    query_embs = text_out.pooler_output  # (B, dim)
                    query_embs = query_embs / query_embs.norm(dim=-1, keepdim=True)
                    scores_matrix = corpus_emb @ query_embs.T  # (n_pages, B)
                    sorted_indices = scores_matrix.argsort(dim=0, descending=True).T.cpu().tolist()
                return [[index_page_ids[j] for j in ranked] for ranked in sorted_indices]

            return await asyncio.to_thread(_gpu_work)

        batcher = DynamicBatcher(
            process_fn=siglip_process_fn,
            target_batch_cost=128,
            max_batch_size=128,
            batch_timeout_s=0.05,
            default_cost=1,
            prefetch_batches=2,
        )
        await batcher.start()
        _siglip_batcher = batcher
    return _siglip_batcher
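
# Sketch (hypothetical class, NOT the extras.DynamicBatcher implementation,
# whose internals are not shown here): a toy micro-batcher that illustrates
# the aggregation idea described above. Each caller submits one item and
# awaits a future. A single background loop waits for the first item, then
# keeps collecting until it holds max_batch_size items or batch_timeout_s has
# passed. It runs process_fn once on the whole batch and resolves every
# caller's future with that caller's own result.
class _SketchMicroBatcher:
    def __init__(self, process_fn, max_batch_size: int = 128, batch_timeout_s: float = 0.05):
        import asyncio

        self._process_fn = process_fn  # async: list of items -> list of results
        self._max = max_batch_size
        self._timeout = batch_timeout_s
        self._queue = asyncio.Queue()
        self._task = None
        self.batch_sizes = []  # size of every batch processed, for inspection

    def start(self) -> None:
        import asyncio

        self._task = asyncio.create_task(self._loop())

    async def submit(self, item):
        import asyncio

        fut = asyncio.get_running_loop().create_future()
        await self._queue.put((item, fut))
        return await fut

    async def _loop(self) -> None:
        import asyncio

        loop = asyncio.get_running_loop()
        while True:
            batch = [await self._queue.get()]  # block until work arrives
            deadline = loop.time() + self._timeout
            while len(batch) < self._max:
                remaining = deadline - loop.time()
                if remaining <= 0:
                    break
                try:
                    batch.append(await asyncio.wait_for(self._queue.get(), remaining))
                except asyncio.TimeoutError:
                    break
            self.batch_sizes.append(len(batch))
            try:
                results = await self._process_fn([item for item, _ in batch])
            except Exception as exc:
                for _, fut in batch:
                    if not fut.done():
                        fut.set_exception(exc)
                continue
            for (_, fut), result in zip(batch, results):
                if not fut.done():
                    fut.set_result(result)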

# ─────────────────────────────────────────────────────────────────────────────
# Helpers
# ─────────────────────────────────────────────────────────────────────────────

def _batches(items: list, batch_size: int):
    """Yield successive fixed-size batches from a list."""
    for start in range(0, len(items), batch_size):
        yield items[start : start + batch_size]

def _load_image_sync(f: File) -> PILImage.Image:
    """Blocking download + decode. Intended to be called from a thread pool."""
    with f.open_sync("rb") as fh:
        data = fh.read()
    return PILImage.open(BytesIO(data)).convert("RGB")

async def _load_image(f: File) -> PILImage.Image:
    """Download and decode a page image in a thread-pool worker.

    asyncio.to_thread runs _load_image_sync in a worker thread, so the event
    loop stays free to run other coroutines during the blocking network read
    and JPEG decode.
    """
    return await asyncio.to_thread(_load_image_sync, f)

async def _load_npz(index_file: File) -> np.lib.npyio.NpzFile:
    """Download an index File to a local temp path and open with np.load."""
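    with tempfile.NamedTemporaryFile(suffix=".npz", delete=False) as tmp:
        async with index_file.open("rb") as fh:
            tmp.write(bytes(await fh.read()))
    # Load only after the with-block has closed (and flushed) the temp file;
    # reading it while still open could see a truncated archive.
    return np.load(tmp.name)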

def _dcg(relevances: list[int]) -> float:
    return sum(rel / math.log2(rank + 2) for rank, rel in enumerate(relevances))
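
# Sketch (hypothetical helper, not used by the pipeline): with exactly one
# relevant page per query, the ideal DCG is _dcg([1]) = 1 / log2(2) = 1.
# NDCG@K therefore reduces to 1 / log2(rank + 1) when the correct page sits at
# 1-based rank <= K, and to 0 otherwise. Rank 1 scores 1.0 and rank 3 scores 0.5.
def _sketch_single_relevant_ndcg(rank: int | None, k: int) -> float:
    import math

    if rank is None or rank > k:
        return 0.0
    return 1.0 / math.log2(rank + 1)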

# ─────────────────────────────────────────────────────────────────────────────
# Tasks — data loading
# ─────────────────────────────────────────────────────────────────────────────

@driver.task(cache="auto", retries=3)
async def load_vidore_pages(subset: str = "docvqa", max_pages: int = 200) -> PageDataset:
    """
    Load a ViDoRe benchmark subset and store page images in Flyte's blob store.

    Supports two dataset formats:

    Legacy (subsampled) — single 'test' split with one row per (query, page)
    pair; fields: image, query, image_filename. streaming=True reads only the
    rows requested via islice — no full-shard download.
    Datasets: vidore/docvqa_test_subsampled, vidore/infovqa_test_subsampled

    V3 — separate corpus / queries / qrels splits following the BEIR retrieval
    benchmark format. corpus contains page images; queries contains question
    text; qrels maps query IDs to relevant corpus page IDs (many-to-many).
    Datasets: vidore/vidore_v3_finance_en  (~2 942 pages, 1 854 queries)

    The first call uploads page images to Flyte's blob store and caches the
    PageDataset; every subsequent call with the same arguments returns the
    cached result instantly. retries=3 guards against transient HuggingFace
    network failures.

    Available subsets: "docvqa", "infovqa", "vidore_v3_finance_en"
    """
    from datasets import load_dataset

    subset_map = {
        "docvqa": "vidore/docvqa_test_subsampled",
        "infovqa": "vidore/infovqa_test_subsampled",
        "vidore_v3_finance_en": "vidore/vidore_v3_finance_en",
    }
    dataset_name = subset_map.get(subset, f"vidore/{subset}_test_subsampled")

    # V3 datasets ship with separate corpus / queries / qrels splits.
    _V3_SUBSETS = {"vidore_v3_finance_en"}

    if subset in _V3_SUBSETS:
        # ── V3 format ─────────────────────────────────────────────────────────
        # corpus / queries / qrels are HuggingFace configs (name=), not splits.
        # corpus uses streaming=True so images are decoded one at a time —
        # loading all 2 942 rows eagerly would hold gigabytes of PIL images in
        # the driver's RAM simultaneously. qrels and queries are text-only and
        # small enough to load fully into memory.
        corpus_ds = load_dataset(dataset_name, name="corpus", split="test", streaming=True)
        qrels_ds = load_dataset(dataset_name, name="qrels", split="test")
        queries_ds = load_dataset(dataset_name, name="queries", split="test")

        # Normalise field names — V3 follows BEIR convention (hyphenated ids).
        def _col(ds, *candidates):
            cols = set(ds.column_names)
            for c in candidates:
                if c in cols:
                    return c
            raise KeyError(f"None of {candidates} found in columns {cols}")

        corpus_id_col = _col(corpus_ds, "corpus-id", "corpus_id", "id", "_id")
        query_id_col = _col(queries_ds, "query-id", "query_id", "id", "_id")
        query_text_col = _col(queries_ds, "query", "text")
        qrel_qid_col = _col(qrels_ds, "query-id", "query_id")
        qrel_cid_col = _col(qrels_ds, "corpus-id", "corpus_id")

        # Slice corpus to max_pages, upload each image to Flyte blob store.
        page_ids: list[str] = []
        page_files: list[File] = []
        corpus_id_to_page_id: dict[str, str] = {}

        for i, row in enumerate(islice(corpus_ds, max_pages)):
            img = row.get("image")
            if not isinstance(img, PILImage.Image):
                continue
            cid = str(row[corpus_id_col])
            page_id = f"{subset}_{i:04d}"
            with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f:
                tmp_path = f.name
                img.convert("RGB").save(tmp_path, format="JPEG")
            del img  # free PIL memory before upload
            page_file = await File.from_local(tmp_path)
            os.unlink(tmp_path)
            corpus_id_to_page_id[cid] = page_id
            page_ids.append(page_id)
            page_files.append(page_file)

        # Build query_id → relevant page_id from qrels (first match wins).
        # Only keep relevance judgements whose corpus page is in our slice.
        qrel_map: dict[str, str] = {}
        for row in qrels_ds:
            qid = str(row[qrel_qid_col])
            cid = str(row[qrel_cid_col])
            if cid in corpus_id_to_page_id and qid not in qrel_map:
                qrel_map[qid] = corpus_id_to_page_id[cid]

        # Collect queries that have at least one relevant page in our slice.
        queries: list[PageQuery] = []
        for row in queries_ds:
            qid = str(row[query_id_col])
            if qid not in qrel_map:
                continue
            queries.append(
                PageQuery(
                    query_id=qid,
                    text=str(row[query_text_col]),
                    relevant_page_id=qrel_map[qid],
                )
            )

    else:
        # ── Legacy format ─────────────────────────────────────────────────────
        # Single 'test' split with one row per (query, page) pair.
        ds = load_dataset(dataset_name, split="test", streaming=True)

        page_ids = []
        page_files = []
        queries = []
        seen_pages: dict[str, str] = {}  # image_filename → page_id

        for i, row in enumerate(islice(ds, max_pages)):
            img = row.get("image")
            if not isinstance(img, PILImage.Image):
                continue
            filename: str = row.get("image_filename") or f"page_{i}"
            query_text: str = row.get("query", "")
            if not query_text:
                continue

            # Each unique page is uploaded exactly once; multiple queries may
            # share the same page (same image_filename).
            if filename not in seen_pages:
                page_id = f"{subset}_{len(page_ids):04d}"
                with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f:
                    tmp_path = f.name
                    img.convert("RGB").save(tmp_path, format="JPEG")
                del img  # free PIL memory before upload
                page_file = await File.from_local(tmp_path)
                os.unlink(tmp_path)
                seen_pages[filename] = page_id
                page_ids.append(page_id)
                page_files.append(page_file)
            else:
                page_id = seen_pages[filename]

            queries.append(
                PageQuery(
                    query_id=f"q{i:04d}",
                    text=query_text,
                    relevant_page_id=page_id,
                )
            )

    print(f"Loaded {len(page_ids)} unique pages, {len(queries)} queries", flush=True)
    return PageDataset(page_ids=page_ids, page_files=page_files, queries=queries)
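
# Sketch (hypothetical helper mirroring the V3 branch above): BEIR-style qrels
# are (query-id, corpus-id) pairs and may list several relevant pages per
# query. The loader keeps only pairs whose corpus page survived the max_pages
# slice. Because PageQuery holds a single relevant_page_id, it keeps the first
# match per query. Queries with no surviving match are dropped.
def _sketch_first_match_qrels(
    qrels: list[tuple[str, str]],
    corpus_id_to_page_id: dict[str, str],
) -> dict[str, str]:
    qrel_map: dict[str, str] = {}
    for qid, cid in qrels:
        if cid in corpus_id_to_page_id and qid not in qrel_map:
            qrel_map[qid] = corpus_id_to_page_id[cid]
    return qrel_map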

# ─────────────────────────────────────────────────────────────────────────────
# Tasks — indexing
# ─────────────────────────────────────────────────────────────────────────────

@colpali_indexer.task(cache="auto", retries=2)
async def index_colpali(page_ids: list[str], page_files: list[File]) -> File:
    """
    Encode every page with ColPali-v1.2 and save the multi-vector index.

    ColPali skips OCR entirely. It feeds the raw page image into PaliGemma
    (a vision-language model) and produces one embedding vector per image
    patch — roughly 1,024 patches per page, each of dimension 128.

    _colpali_model() is an lru_cache'd loader. On a cold container, it
    downloads and loads the model once. On a warm container (kept alive by
    ReusePolicy), it returns the already-loaded model instantly from cache —
    no repeated ~7 GB download.

    The index is stored as a .npz file in Flyte's blob store:
      embeddings — float32, shape (n_pages, n_patches, dim)
      page_ids   — matching page ID strings

    cache="auto" + retries=2: on success the index is cached and reused for
    identical inputs; transient failures (e.g. HuggingFace rate limits) are
    retried up to twice.
    """
    import torch

    model, processor, device = _colpali_model()

    loop = asyncio.get_running_loop()
    batches = list(_batches(page_files, 4))
    n_batches = len(batches)

    # Submit the first batch to the thread pool before entering the loop so
    # that downloads are already in flight when we first await them.
    prefetch = [loop.run_in_executor(None, _load_image_sync, f) for f in batches[0]]

    all_embeddings: list[np.ndarray] = []
    for batch_idx in range(n_batches):
        images = list(await asyncio.gather(*prefetch))

        # Submit next batch downloads immediately — OS threads run these in
        # parallel with the GPU forward pass below.
        if batch_idx + 1 < n_batches:
            prefetch = [loop.run_in_executor(None, _load_image_sync, f) for f in batches[batch_idx + 1]]

        inputs = processor.process_images(images)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            emb = model(**inputs)  # (batch, n_patches, dim)

        all_embeddings.append(emb.cpu().float().numpy())
        print(f"ColPali: indexed batch {batch_idx + 1}/{n_batches}", flush=True)

    embeddings = np.concatenate(all_embeddings, axis=0)  # (n_pages, n_patches, dim)
    out_path = os.path.join(tempfile.gettempdir(), "colpali_index.npz")
    np.savez(out_path, embeddings=embeddings, page_ids=np.array(page_ids))
    return await File.from_local(out_path)

@siglip_indexer.task(cache="auto", retries=2)
async def index_siglip(page_ids: list[str], page_files: list[File]) -> File:
    """
    Encode every page with SigLIP SO400M and save the single-vector index.

    SigLIP (2023) is Google's successor to CLIP, trained with sigmoid loss
    instead of softmax — avoiding the normalisation bottleneck that limits
    CLIP's scalability. Produces one global embedding per page.

    _siglip_model() caches the model across warm container reuses.

    The index is stored as a .npz file:
      embeddings — float32, shape (n_pages, dim), L2-normalised
      page_ids   — matching page ID strings
    """
    import torch

    model, processor, device = _siglip_model()

    loop = asyncio.get_running_loop()
    batches = list(_batches(page_files, 8))
    n_batches = len(batches)

    # Submit the first batch to the thread pool before entering the loop so
    # that downloads are already in flight when we first await them.
    prefetch = [loop.run_in_executor(None, _load_image_sync, f) for f in batches[0]]

    all_embeddings: list[np.ndarray] = []
    for batch_idx in range(n_batches):
        images = list(await asyncio.gather(*prefetch))

        # Submit next batch downloads immediately — OS threads run these in
        # parallel with the GPU forward pass below.
        if batch_idx + 1 < n_batches:
            prefetch = [loop.run_in_executor(None, _load_image_sync, f) for f in batches[batch_idx + 1]]

        inputs = processor(images=images, return_tensors="pt", padding=True).to(device)

        with torch.no_grad():
            outputs = model.vision_model(**inputs)
            emb = outputs.pooler_output  # (batch, dim)
            emb = emb / emb.norm(dim=-1, keepdim=True)  # L2 normalise

        all_embeddings.append(emb.cpu().float().numpy())
        print(f"SigLIP: indexed batch {batch_idx + 1}/{n_batches}", flush=True)

    embeddings = np.concatenate(all_embeddings, axis=0)  # (n_pages, dim)
    out_path = os.path.join(tempfile.gettempdir(), "siglip_index.npz")
    np.savez(out_path, embeddings=embeddings, page_ids=np.array(page_ids))
    return await File.from_local(out_path)
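
# Sketch (hypothetical helper, not used by the pipeline): the double-buffered
# loop from index_colpali / index_siglip with the model stripped out. While
# `compute` (blocking, like a GPU forward pass) runs on the current batch, the
# next batch's `load` calls are already running in the default thread-pool
# executor, so download latency hides behind compute. Unlike the tasks above,
# the sketch also returns early on an empty input list.
async def _sketch_prefetch_map(items: list, batch_size: int, load, compute) -> list:
    import asyncio

    loop = asyncio.get_running_loop()
    batches = [items[i : i + batch_size] for i in range(0, len(items), batch_size)]
    if not batches:
        return []
    pending = [loop.run_in_executor(None, load, x) for x in batches[0]]
    outputs = []
    for idx in range(len(batches)):
        loaded = await asyncio.gather(*pending)
        if idx + 1 < len(batches):
            # Kick off the next batch's downloads before the blocking compute.
            pending = [loop.run_in_executor(None, load, x) for x in batches[idx + 1]]
        outputs.extend(compute(list(loaded)))
    return outputs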

@ocr_engine.task(cache="auto")
async def extract_page_texts(page_files: list[File]) -> list[str]:
    """
    OCR every page with doctr on GPU to produce a text-only baseline.

    doctr bundles DBNet (detection) + CRNN/SAR (recognition) into a single
    callable predictor. Pages are processed ocr_batch_size at a time: each
    batch is downloaded in parallel and then OCR'd, with asyncio.to_thread
    keeping the event loop unblocked during GPU inference.

    Result structure: result.pages[i].blocks[j].lines[k].words[l].value

    Cached: the same corpus is OCR'd at most once across all experiments
    that use the OCR+BM25 backend.
    """
    import gc

    predictor = _ocr_model()

    # Process in batches: download each batch just-in-time so only
    # ocr_batch_size images are in memory at once instead of the whole corpus.
    ocr_batch_size = 8
    total = len(page_files)
    texts: list[str] = []
    for start in range(0, total, ocr_batch_size):
        batch_files = page_files[start : start + ocr_batch_size]
        batch_images = list(
            await asyncio.gather(*[asyncio.to_thread(_load_image_sync, f) for f in batch_files])
        )
        batch_np = [np.array(img) for img in batch_images]
        del batch_images
        result = await asyncio.to_thread(predictor, batch_np)
        del batch_np
        for page_output in result.pages:
            texts.append(
                "\n".join(
                    " ".join(word.value for word in line.words)
                    for block in page_output.blocks
                    for line in block.lines
                )
            )
        del result
        gc.collect()
        print(f"OCR: processed {min(start + ocr_batch_size, total)}/{total} pages", flush=True)

    return texts
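
# Sketch (hypothetical stand-in types, not doctr's classes): how
# extract_page_texts flattens doctr's page -> block -> line -> word tree.
# Here a page is a list of blocks, a block a list of lines, and a line a list
# of word strings. Words on a line are joined with spaces and lines with
# newlines. Block boundaries are not marked, so BM25 sees one flat token
# stream per page.
def _sketch_flatten_page(blocks: list[list[list[str]]]) -> str:
    return "\n".join(" ".join(words) for lines in blocks for words in lines)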

# ─────────────────────────────────────────────────────────────────────────────
# Tasks — search
# ─────────────────────────────────────────────────────────────────────────────

# {{docs-fragment search_colpali}}
@colpali_indexer.task
async def search_colpali(
    index_file: File,
    queries: list[PageQuery],
    top_k: int,
) -> list[RetrievalResult]:
    """
    Retrieve pages using ColPali MaxSim late interaction via DynamicBatcher.

    MaxSim score for page p given query q:
        score(q, p) = Σ_{t ∈ query tokens} max_{j ∈ page patches} (q_t · p_j)

    Each query is submitted to the process-level DynamicBatcher, which
    aggregates queries from all concurrent search_colpali invocations on the
    same warm container (concurrency=8) into a single GPU batch. This keeps
    the GPU saturated rather than running one small batch per caller.

    The batcher's process_fn runs GPU work in asyncio.to_thread, so the
    aggregation loop stays live while the GPU encodes and scores.
    """
    batcher = await _get_colpali_search_batcher(index_file)
    futures = await batcher.submit_batch(queries)
    all_ranked: list[list[str]] = list(await asyncio.gather(*futures))

    return [
        RetrievalResult(query_id=q.query_id, ranked_page_ids=ranked[:top_k])
        for q, ranked in zip(queries, all_ranked)
    ]
# {{/docs-fragment search_colpali}}
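
# Sketch (hypothetical helper, not used by the pipeline): a NumPy reference
# for the MaxSim score in search_colpali's docstring, useful for checking the
# einsum in colpali_process_fn on small arrays. Shapes follow that code:
# queries are (B, T, D) token embeddings and the corpus is (P, J, D) patch
# embeddings. For each query token it takes the best-matching patch on a page,
# then sums over tokens. Returns a (B, P) score matrix.
def _sketch_maxsim_scores(query_embs, corpus_embs):
    import numpy as np

    sims = np.einsum("btd,pjd->btpj", query_embs, corpus_embs)  # all token-patch dot products
    return sims.max(axis=3).sum(axis=1)  # max over patches, then sum over tokens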

@siglip_indexer.task
async def search_siglip(
    index_file: File,
    queries: list[PageQuery],
    top_k: int,
) -> list[RetrievalResult]:
    """
    Retrieve pages using SigLIP cosine similarity via DynamicBatcher.

    Each query is submitted to the process-level DynamicBatcher, which
    aggregates queries from all concurrent search_siglip invocations on the
    same warm container (concurrency=8) into a single GPU batch.

    SigLIP's single-vector embeddings make full vectorisation safe —
    the scores matrix (n_pages x n_queries) is small enough to materialise
    in one GPU call regardless of batch size.
    """
    batcher = await _get_siglip_search_batcher(index_file)
    futures = await batcher.submit_batch(queries)
    all_ranked: list[list[str]] = list(await asyncio.gather(*futures))

    return [
        RetrievalResult(query_id=q.query_id, ranked_page_ids=ranked[:top_k])
        for q, ranked in zip(queries, all_ranked)
    ]
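
# Sketch (hypothetical helper, not used by the pipeline): single-vector
# retrieval as done in siglip_process_fn. Once page and query embeddings are
# L2-normalised, cosine similarity is a plain dot product, so one matrix
# multiply scores every (page, query) pair. Sorting each column gives the
# ranking. Returns, per query, page indices ordered best to worst.
def _sketch_cosine_rank(corpus_embs, query_embs):
    import numpy as np

    corpus = corpus_embs / np.linalg.norm(corpus_embs, axis=-1, keepdims=True)
    queries = query_embs / np.linalg.norm(query_embs, axis=-1, keepdims=True)
    scores = corpus @ queries.T  # (n_pages, n_queries)
    return np.argsort(-scores, axis=0).T  # (n_queries, n_pages)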

@driver.task
async def search_bm25(
    page_texts: list[str],
    page_ids: list[str],
    queries: list[PageQuery],
    top_k: int,
) -> list[RetrievalResult]:
    """
    Retrieve pages using BM25 over OCR'd text.

    The standard keyword-based baseline. BM25 scoring needs no GPU (the OCR
    step runs upstream in extract_page_texts); strong on text-dense pages,
    weak on visual content that OCR cannot read.
    """
    tokenized = [text.lower().split() for text in page_texts]
    bm25 = BM25Okapi(tokenized)

    results: list[RetrievalResult] = []
    for q in queries:
        scores = bm25.get_scores(q.text.lower().split())
        ranked = sorted(range(len(page_ids)), key=lambda i: -scores[i])[:top_k]
        results.append(
            RetrievalResult(
                query_id=q.query_id,
                ranked_page_ids=[page_ids[i] for i in ranked],
            )
        )
    return results

# ─────────────────────────────────────────────────────────────────────────────
# Tasks — evaluation
# ─────────────────────────────────────────────────────────────────────────────

@driver.task
async def evaluate(
    results: list[RetrievalResult],
    ground_truth: list[PageQuery],
    k: int,
) -> Metrics:
    """
    Compute Recall@K, NDCG@K, and MRR for a single retrieval model.

    Recall@K  — was the correct page in the top-K results?
    NDCG@K    — normalised discounted cumulative gain; rewards earlier hits.
    MRR       — mean reciprocal rank of the first correct result.

    All three are averaged over all queries. Higher is better.
    """
    gt_map = {q.query_id: q.relevant_page_id for q in ground_truth}
    recall_vals, ndcg_vals, mrr_vals = [], [], []

    for r in results:
        relevant = gt_map.get(r.query_id, "")
        top = r.ranked_page_ids[:k]

        recall_vals.append(1.0 if relevant in top else 0.0)

        rels = [1 if pid == relevant else 0 for pid in top]
        idcg = _dcg([1])  # ideal: correct page at rank 1
        ndcg_vals.append(_dcg(rels) / idcg if idcg > 0 else 0.0)

        rr = 0.0
        for rank, pid in enumerate(r.ranked_page_ids, start=1):
            if pid == relevant:
                rr = 1.0 / rank
                break
        mrr_vals.append(rr)

    return Metrics(
        recall_at_k=float(np.mean(recall_vals)),
        ndcg_at_k=float(np.mean(ndcg_vals)),
        mrr=float(np.mean(mrr_vals)),
        k=k,
    )

# ─────────────────────────────────────────────────────────────────────────────
# Tasks — report
# ─────────────────────────────────────────────────────────────────────────────

@driver.task(report=True)
async def generate_report(report: ComparisonReport) -> None:
    """
    Emit an interactive HTML report visible in the Flyte UI.

    report=True marks this task as a reporting task. Flyte renders the HTML
    passed to flyte.report.replace.aio() directly in the execution detail
    page; no separate dashboard or export step is required.

    The report contains:
      - Summary cards: experiment count, the best experiment by Recall@K,
        and that experiment's Recall@K, NDCG@K, and MRR.
      - Grouped bar chart: Recall@K, NDCG@K, MRR side-by-side per experiment.
      - Ranked results table with all three metrics.
    """
    sorted_results = sorted(report.results, key=lambda r: -r.metrics.recall_at_k)
    best = sorted_results[0]

    labels = [r.config.name for r in sorted_results]
    recall_vals = [r.metrics.recall_at_k for r in sorted_results]
    ndcg_vals = [r.metrics.ndcg_at_k for r in sorted_results]
    mrr_vals = [r.metrics.mrr for r in sorted_results]

    table_rows = "".join(
        f"""
        <tr>
          <td>{r.config.name}</td>
          <td>{r.config.model.value}</td>
          <td>{r.metrics.recall_at_k:.3f}</td>
          <td>{r.metrics.ndcg_at_k:.3f}</td>
          <td>{r.metrics.mrr:.3f}</td>
          <td>{r.metrics.k}</td>
        </tr>"""
        for r in sorted_results
    )

    html = f"""<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <title>Visual Document Retrieval — Results</title>
  <script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
  <style>
    * {{ box-sizing: border-box; margin: 0; padding: 0; }}
    body {{
      font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;
      background: #f0f2f5; color: #222; padding: 24px;
    }}
    h1 {{ font-size: 1.6em; margin-bottom: 4px; }}
    .subtitle {{ color: #666; margin-bottom: 24px; font-size: 0.95em; }}
    .cards {{
      display: flex; gap: 16px; flex-wrap: wrap; margin-bottom: 28px;
    }}
    .card {{
      background: #fff; border-radius: 10px; padding: 18px 24px;
      box-shadow: 0 1px 4px rgba(0,0,0,.08); min-width: 160px;
    }}
    .card-value {{ font-size: 1.9em; font-weight: 700; color: #4f46e5; }}
    .card-label {{ font-size: 0.8em; color: #888; text-transform: uppercase;
                   letter-spacing: .04em; margin-top: 2px; }}
    .chart-box {{
      background: #fff; border-radius: 10px; padding: 24px;
      box-shadow: 0 1px 4px rgba(0,0,0,.08); margin-bottom: 28px;
    }}
    .chart-box h2 {{ font-size: 1em; margin-bottom: 16px; color: #444; }}
    table {{ width: 100%; border-collapse: collapse; font-size: 0.9em; }}
    th {{
      background: #4f46e5; color: #fff; padding: 10px 14px;
      text-align: left; font-weight: 600;
    }}
    td {{ padding: 9px 14px; border-bottom: 1px solid #eee; }}
    tr:hover td {{ background: #f8f8ff; }}
    tr:first-child td {{ font-weight: 600; }}
  </style>
</head>
<body>
  <h1>Visual Document Retrieval — Experiment Comparison</h1>
  <p class="subtitle">ViDoRe benchmark &middot; {len(report.results)} experiment(s)</p>

  <div class="cards">
    <div class="card">
      <div class="card-value">{len(report.results)}</div>
      <div class="card-label">Experiments</div>
    </div>
    <div class="card">
      <div class="card-value">{best.config.name}</div>
      <div class="card-label">Best by Recall@K</div>
    </div>
    <div class="card">
      <div class="card-value">{best.metrics.recall_at_k:.3f}</div>
      <div class="card-label">Best Recall@{best.metrics.k}</div>
    </div>
    <div class="card">
      <div class="card-value">{best.metrics.ndcg_at_k:.3f}</div>
      <div class="card-label">Best NDCG@{best.metrics.k}</div>
    </div>
    <div class="card">
      <div class="card-value">{best.metrics.mrr:.3f}</div>
      <div class="card-label">Best MRR</div>
    </div>
  </div>

  <div class="chart-box">
    <h2>Metrics by Experiment</h2>
    <canvas id="metricsChart" height="100"></canvas>
  </div>

  <div class="chart-box">
    <h2>Ranked Results</h2>
    <table>
      <thead>
        <tr>
          <th>Experiment</th><th>Model</th>
          <th>Recall@K</th><th>NDCG@K</th><th>MRR</th><th>K</th>
        </tr>
      </thead>
      <tbody>{table_rows}</tbody>
    </table>
  </div>

  <script>
    new Chart(document.getElementById('metricsChart'), {{
      type: 'bar',
      data: {{
        labels: {json.dumps(labels)},
        datasets: [
          {{
            label: 'Recall@K',
            data: {json.dumps(recall_vals)},
            backgroundColor: 'rgba(79,70,229,0.75)',
            borderRadius: 4
          }},
          {{
            label: 'NDCG@K',
            data: {json.dumps(ndcg_vals)},
            backgroundColor: 'rgba(16,185,129,0.75)',
            borderRadius: 4
          }},
          {{
            label: 'MRR',
            data: {json.dumps(mrr_vals)},
            backgroundColor: 'rgba(245,158,11,0.75)',
            borderRadius: 4
          }}
        ]
      }},
      options: {{
        responsive: true,
        plugins: {{ legend: {{ position: 'top' }} }},
        scales: {{
          y: {{ beginAtZero: true, max: 1.0,
               title: {{ display: true, text: 'Score' }} }}
        }}
      }}
    }});
  </script>
</body>
</html>"""

    await flyte.report.replace.aio(html)
    await flyte.report.flush.aio()

# ─────────────────────────────────────────────────────────────────────────────
# Experiment orchestration
# ─────────────────────────────────────────────────────────────────────────────

# {{docs-fragment run_experiment}}
@driver.task
async def run_experiment(config: ExperimentConfig, dataset: PageDataset) -> ExperimentResult:
    """
    End-to-end retrieval pipeline for a single ExperimentConfig.

    Flyte v2's dynamic execution means this driver task can call GPU tasks
    (index_colpali, search_colpali) based on the runtime value of config.model
    — no static DAG wiring required. The if/elif is plain Python; Flyte
    schedules the selected sub-tasks on the appropriate environment.

    Caching: two experiments that share the same model and corpus (e.g. ColPali
    at top_k=5 and top_k=10) will hit the same cached index. GPU work is paid
    at most once per (model, corpus) pair across all experiments.

    Search queries are sharded into chunks of SEARCH_SHARD_SIZE and dispatched
    as concurrent task invocations. All shards land on the single warm container
    (replicas=1) and feed the same DynamicBatcher simultaneously, keeping the
    GPU saturated throughout search rather than processing one large sequential
    batch from a single caller.

    flyte.group wraps each experiment in a named span in the Flyte UI, making
    it easy to compare latencies and drill into individual runs.
    """
    SEARCH_SHARD_SIZE = 256

    with flyte.group(config.name):
        if config.model == RetrievalModel.COLPALI:
            index_file = await index_colpali(dataset.page_ids, dataset.page_files)
            shards = list(_batches(dataset.queries, SEARCH_SHARD_SIZE))
            shard_results = await asyncio.gather(
                *[search_colpali(index_file, shard, config.top_k) for shard in shards]
            )
            results = [r for shard in shard_results for r in shard]

        elif config.model == RetrievalModel.SIGLIP:
            index_file = await index_siglip(dataset.page_ids, dataset.page_files)
            shards = list(_batches(dataset.queries, SEARCH_SHARD_SIZE))
            shard_results = await asyncio.gather(
                *[search_siglip(index_file, shard, config.top_k) for shard in shards]
            )
            results = [r for shard in shard_results for r in shard]

        else:  # RetrievalModel.OCR_BM25
            page_texts = await extract_page_texts(dataset.page_files)
            results = await search_bm25(page_texts, dataset.page_ids, dataset.queries, config.top_k)

        metrics = await evaluate(results, dataset.queries, config.top_k)

    return ExperimentResult(config=config, metrics=metrics)
# {{/docs-fragment run_experiment}}

# {{docs-fragment compare_experiments}}
@driver.task
async def compare_experiments(
    configs: list[ExperimentConfig],
    subset: str = "docvqa",
    max_pages: int = 200,
) -> ComparisonReport:
    """
    Fan out over all experiment configs and return a ranked comparison table.

    The dataset is loaded once and shared across all experiments. Each config
    runs as a concurrent Flyte task via asyncio.gather. Experiments that share
    a model reuse the cached index — you only pay GPU time for new work.

    On completion, generate_report emits an interactive Chart.js HTML report
    visible directly in the Flyte execution detail page.

    The __main__ entry point runs vidore_v3_finance_en (~2,942 corpus pages,
    1,854 queries) with max_pages=2000 to exercise the GPU pipeline at scale;
    the function defaults (subset="docvqa", max_pages=200) give a smaller run.
    """
    dataset = await load_vidore_pages(subset=subset, max_pages=max_pages)

    # All experiments launch concurrently. Shared cached outputs (same model,
    # same corpus) are served from cache rather than recomputed.
    experiment_coros = [run_experiment(config=cfg, dataset=dataset) for cfg in configs]
    results: list[ExperimentResult] = list(await asyncio.gather(*experiment_coros))

    report = ComparisonReport(results=results)
    print(report.summary())
    best = report.best_by("recall_at_k")
    print(f"\nBest by Recall@{best.metrics.k}: {best.config.name}")

    # Emit the interactive HTML report in the Flyte UI.
    await generate_report(report)

    return report
# {{/docs-fragment compare_experiments}}

# ─────────────────────────────────────────────────────────────────────────────
# Entry point
# ─────────────────────────────────────────────────────────────────────────────

if __name__ == "__main__":
    flyte.init_from_config()

    # Define the experiment grid. Each ExperimentConfig is one point in the
    # design space. Adding a new model or varying top_k is one line here —
    # no task code changes required.
    #
    # ColPali appears twice with different top_k values. The cache ensures
    # index_colpali runs only once and both experiments share that result.
    # {{docs-fragment grid}}
    configs = [
        ExperimentConfig(name="colpali-top5", model=RetrievalModel.COLPALI, top_k=5),
        ExperimentConfig(name="colpali-top10", model=RetrievalModel.COLPALI, top_k=10),
        ExperimentConfig(name="siglip-top5", model=RetrievalModel.SIGLIP, top_k=5),
        ExperimentConfig(name="ocr-bm25-top5", model=RetrievalModel.OCR_BM25, top_k=5),
    ]
    # {{/docs-fragment grid}}

    run = flyte.with_runcontext().run(
        compare_experiments,
        configs=configs,
        subset="vidore_v3_finance_en",
        max_pages=2000,
    )
    print(f"Run URL: {run.url}")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/multimodal-retrieval-evaluation/retrieval_eval.py*
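
The MaxSim formula in the `search_colpali` docstring and the chunked `torch.einsum("ctd,pjd->ctpj", ...)` call in the ColPali search batcher compute the same quantity. As a sanity check, here is a NumPy-only sketch on random toy embeddings (the shapes and the names `maxsim_einsum` and `maxsim_loop` are illustrative assumptions, not part of the tutorial) that evaluates the einsum form and compares it with the literal double loop from the formula.

```python
import numpy as np

rng = np.random.default_rng(0)
# Toy sizes: C queries, T query tokens, P pages, J patches per page, D dims.
n_queries, n_tokens, n_pages, n_patches, dim = 3, 5, 4, 6, 8
queries = rng.standard_normal((n_queries, n_tokens, dim))  # (C, T, D)
pages = rng.standard_normal((n_pages, n_patches, dim))     # (P, J, D)


def maxsim_einsum(q: np.ndarray, p: np.ndarray) -> np.ndarray:
    # (C, T, P, J) token-patch similarities -> best patch per token -> sum over tokens.
    return np.einsum("ctd,pjd->ctpj", q, p).max(axis=3).sum(axis=1)  # (C, P)


def maxsim_loop(q: np.ndarray, p: np.ndarray) -> np.ndarray:
    # Literal form: score(q, p) = sum over tokens t of max over patches j of q_t . p_j
    scores = np.zeros((q.shape[0], p.shape[0]))
    for c in range(q.shape[0]):
        for page in range(p.shape[0]):
            scores[c, page] = sum(
                max(q[c, t] @ p[page, j] for j in range(p.shape[1]))
                for t in range(q.shape[1])
            )
    return scores


scores = maxsim_einsum(queries, pages)
assert np.allclose(scores, maxsim_loop(queries, pages))
ranking = np.argsort(-scores, axis=1)  # best page first, like argsort(descending=True)
print(scores.shape, ranking.shape)  # (3, 4) (3, 4)
```

The intermediate `(C, T, P, J)` tensor grows with every query in the chunk, which is presumably why the tutorial scores queries eight at a time: chunking bounds peak GPU memory without changing the result.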

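The `evaluate` task averages three per-query scores over all queries. A minimal plain-Python sketch of those per-query computations makes the numbers concrete; `score_query` is a hypothetical name, and `dcg` mirrors the tutorial's `_dcg` helper.

```python
import math


def dcg(relevances: list[int]) -> float:
    # Rank 0 is divided by log2(2) = 1, rank 1 by log2(3), and so on.
    return sum(rel / math.log2(rank + 2) for rank, rel in enumerate(relevances))


def score_query(ranked: list[str], relevant: str, k: int) -> tuple[float, float, float]:
    """Recall@K, NDCG@K and reciprocal rank for one query with one relevant page."""
    top = ranked[:k]
    recall = 1.0 if relevant in top else 0.0
    # With a single relevant page the ideal ranking puts it first, so IDCG = dcg([1]) = 1.
    ndcg = dcg([1 if pid == relevant else 0 for pid in top]) / dcg([1])
    # Reciprocal rank of the first hit anywhere in the ranked list, else 0.
    rr = next((1.0 / rank for rank, pid in enumerate(ranked, start=1) if pid == relevant), 0.0)
    return recall, ndcg, rr


# The relevant page sits at rank 3 of a top-5 list.
print(score_query(["p7", "p2", "p9", "p4", "p1"], relevant="p9", k=5))
# (1.0, 0.5, 0.3333333333333333)

# A miss scores zero on all three metrics.
print(score_query(["p7", "p2"], relevant="p9", k=2))
# (0.0, 0.0, 0.0)
```

Because the search tasks already truncate `ranked_page_ids` to `top_k`, the reciprocal rank computed in `evaluate` is effectively MRR@K: a correct page outside the top K contributes 0.
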
## Compare experiments

The driver loads the dataset once, fans out across all configs with `asyncio.gather`, and emits an interactive Chart.js report in the Flyte UI. Experiments sharing a model reuse the cached index, so you only pay GPU time for new work.
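
The fan-out pattern the pipeline relies on (shard the queries, launch every shard concurrently with `asyncio.gather`, and let one shared batcher merge whatever is in flight into a single forward pass) can be sketched with plain asyncio. `TinyBatcher` below is a hypothetical stand-in, not the `DynamicBatcher` from the tutorial's `extras` module, whose internals are not shown here; only the `submit_batch` name mirrors the tutorial's call. It omits cost-based batching, prefetching, and error handling, and `fake_encode` stands in for the GPU work.

```python
import asyncio


class TinyBatcher:
    """Hypothetical minimal batcher: merges items from concurrent callers into one call."""

    def __init__(self, process, max_batch_size: int = 128, timeout_s: float = 0.05):
        self.process = process  # async fn: list[item] -> list[result], same order
        self.max_batch_size = max_batch_size
        self.timeout_s = timeout_s
        self.queue: asyncio.Queue = asyncio.Queue()
        self.batch_sizes: list[int] = []  # recorded only to show the aggregation

    def start(self) -> None:
        # Background task that drains the queue for the lifetime of the event loop.
        self._worker = asyncio.create_task(self._run())

    async def submit_batch(self, items: list) -> list[asyncio.Future]:
        # Enqueue every item and hand back one future per item.
        loop = asyncio.get_running_loop()
        futures = [loop.create_future() for _ in items]
        for item, fut in zip(items, futures):
            self.queue.put_nowait((item, fut))
        return futures

    async def _run(self) -> None:
        loop = asyncio.get_running_loop()
        while True:
            batch = [await self.queue.get()]  # wait for the first item
            deadline = loop.time() + self.timeout_s
            # Keep collecting until the batch is full or the timeout expires.
            while len(batch) < self.max_batch_size:
                remaining = deadline - loop.time()
                if remaining <= 0:
                    break
                try:
                    batch.append(await asyncio.wait_for(self.queue.get(), remaining))
                except asyncio.TimeoutError:
                    break
            results = await self.process([item for item, _ in batch])
            self.batch_sizes.append(len(batch))
            for (_, fut), result in zip(batch, results):
                fut.set_result(result)


async def fake_encode(batch: list[str]) -> list[str]:
    # Stand-in for the GPU forward pass: one call handles the whole batch.
    return [q.upper() for q in batch]


async def search_shard(batcher: TinyBatcher, shard: list[str]) -> list[str]:
    futures = await batcher.submit_batch(shard)
    return list(await asyncio.gather(*futures))


async def main() -> tuple[list[str], list[int]]:
    batcher = TinyBatcher(fake_encode)
    batcher.start()
    queries = [f"q{i}" for i in range(10)]
    shards = [queries[i : i + 4] for i in range(0, len(queries), 4)]  # 4 + 4 + 2
    # Fan out: every shard is in flight at once, so the batcher sees all of them.
    shard_results = await asyncio.gather(*(search_shard(batcher, s) for s in shards))
    return [r for shard in shard_results for r in shard], batcher.batch_sizes


results, sizes = asyncio.run(main())
print(results[:3], sizes)  # ['Q0', 'Q1', 'Q2'] [10]
```

The three shards are submitted before the batcher wakes, so all ten queries land in one batch while results still come back in shard order. A real GPU batcher would additionally cap batches by cost, as the tutorial's `target_batch_cost` parameter suggests.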

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "colpali-engine>=0.3.1",
#     "transformers>=4.41",
#     "sentencepiece>=0.2",
#     "torch>=2.0",
#     "pillow>=10",
#     "datasets>=2.18",
#     "rank-bm25>=0.2",
#     "numpy>=1.26",
#     "python-doctr[torch]>=0.8",
#     "pydantic>=2.0",
#     "flyte>=2.0.0",
# ]
# ///
"""
Multimodal Retrieval Evaluation Pipeline

This tutorial builds an experiment framework for benchmarking visual document
retrieval approaches on the ViDoRe benchmark. Each experiment is defined by
an ExperimentConfig; the pipeline fans them out as concurrent Flyte tasks and
returns a ranked comparison table with an interactive HTML report.

The corpus is a set of PDF page images; queries are plain-text questions. Each
retrieval method must find the page that answers each question — no text is
provided to the model, only the raw image.

  ColPali-v1.2  — patch-level multi-vector embeddings from a VLM (PaliGemma).
                  No OCR. The model produces one vector per image patch
                  (~1024 per page). MaxSim late-interaction scoring finds the
                  best matching patch for each query token.

  SigLIP-SO400M — single global embedding per page from Google's 2023 CLIP
                  successor. One matrix multiply per query; fast and effective
                  but a single vector cannot localise fine-grained regions.

  OCR + BM25    — text-only baseline. doctr (GPU OCR) extracts text in
                  batches, BM25 matches keywords. Strong on text-dense pages;
                  fails on charts, tables, and figures where content is visual.

"""

import asyncio
import enum
import json
import math
import os
import tempfile
from functools import lru_cache
from io import BytesIO
from itertools import islice

import numpy as np
from PIL import Image as PILImage
from pydantic import BaseModel
from rank_bm25 import BM25Okapi

from extras import DynamicBatcher

import flyte
import flyte.report
from flyte.io import File

# ─────────────────────────────────────────────────────────────────────────────
# Environments
# ─────────────────────────────────────────────────────────────────────────────

# One Docker image for all tasks. The PEP 723 header defines Python deps.
# ca-certificates is required for HTTPS calls to HuggingFace and blob stores.
# {{docs-fragment image}}
image = (
    flyte.Image.from_uv_script(__file__, name="vidore-eval-v2")
    .with_apt_packages("ca-certificates", "libxcb1", "libgl1", "libglib2.0-0")
    # unionai-reuse installs the unionai-actor-bridge binary required by ReusePolicy.
    # Without it every reusable container exits with StartError (exit code 128).
    .with_pip_packages("unionai-reuse>=0.1.11")
)
# {{/docs-fragment image}}

# GPU environment for ColPali image encoding and search.
#
# ReusePolicy keeps a warm GPU container (replicas=1) alive between task calls.
# Without it, every task invocation cold-starts a new container and downloads
# ColPali-v1.2 (~7 GB) from scratch. With it, the container — and the model
# weights already loaded into VRAM — is reused for the next task dispatch.
#
#   replicas=1      single warm container — all concurrent shard calls land
#                   here so they share one DynamicBatcher process
#   concurrency=8   up to 8 query-shard tasks run simultaneously on the
#                   container, all feeding the same DynamicBatcher queue
#   idle_ttl=120    keep alive 2 min after the last task finishes
#   scaledown_ttl=60 scale to zero after 1 min of complete inactivity
# {{docs-fragment envs}}
colpali_indexer = flyte.TaskEnvironment(
    name="vidore-colpali-indexer",
    image=image,
    resources=flyte.Resources(cpu=4, memory="16Gi", gpu="A10G:1"),
    reusable=flyte.ReusePolicy(
        replicas=1,
        concurrency=8,
        idle_ttl=120,
        scaledown_ttl=60,
    ),
)

# GPU environment for SigLIP image encoding and search.
#
# Separate from the ColPali environment so each model's warm containers
# are managed independently — ColPali and SigLIP experiments can scale
# without contending for the same pool of reusable containers.
siglip_indexer = flyte.TaskEnvironment(
    name="vidore-siglip-indexer",
    image=image,
    resources=flyte.Resources(cpu=4, memory="8Gi", gpu=1),
    reusable=flyte.ReusePolicy(
        replicas=1,
        concurrency=8,
        idle_ttl=120,
        scaledown_ttl=60,
    ),
)

# GPU environment for doctr OCR. doctr runs DBNet (detection) + CRNN (recognition)
# in batches on GPU — much faster than CPU Tesseract.
# No ReusePolicy needed: the result is cached, so this task runs at most once
# per corpus.
ocr_engine = flyte.TaskEnvironment(
    name="vidore-ocr-engine",
    image=image,
    resources=flyte.Resources(cpu=4, memory="20Gi", gpu=1),
)

# Driver: orchestration, BM25 search, evaluation, and reporting.
# depends_on makes the GPU environments deploy together with the driver, so
# their image is built and their tasks are registered before the driver calls them.
driver = flyte.TaskEnvironment(
    name="vidore-driver",
    image=image,
    resources=flyte.Resources(cpu=2, memory="12Gi"),
    depends_on=[colpali_indexer, siglip_indexer, ocr_engine],
)
# {{/docs-fragment envs}}

# ─────────────────────────────────────────────────────────────────────────────
# Configuration types
# ─────────────────────────────────────────────────────────────────────────────

# {{docs-fragment config_types}}
class RetrievalModel(str, enum.Enum):
    """Retrieval backend to evaluate."""

    COLPALI = "colpali-v1.2"  # multi-vector patch embeddings, MaxSim
    SIGLIP = "siglip-so400m"  # single-vector global embedding, cosine sim
    OCR_BM25 = "ocr+bm25"  # text extracted by doctr OCR, ranked by BM25

class ExperimentConfig(BaseModel):
    """
    All knobs for one retrieval experiment. Passed as a typed Flyte input.

    Because ExperimentConfig is a Pydantic model, Flyte records it as a typed
    input of every task that receives it, and ExperimentResult stores it next
    to its metrics, so you can always reconstruct which config produced which
    metric without maintaining a separate log.
    """

    name: str  # human-readable label shown in the comparison table
    model: RetrievalModel
    top_k: int = 5  # number of pages to retrieve per query
# {{/docs-fragment config_types}}

# ─────────────────────────────────────────────────────────────────────────────
# Data types
# ─────────────────────────────────────────────────────────────────────────────

# {{docs-fragment data_types}}
class PageQuery(BaseModel):
    """One retrieval query with its ground-truth page."""

    query_id: str
    text: str  # e.g. "What was revenue growth in Q3?"
    relevant_page_id: str  # one correct page per query

class PageDataset(BaseModel):
    """
    A corpus of document page images paired with text queries.

    page_ids:   unique page identifiers (derived from ViDoRe image filenames).
    page_files: the same pages stored in Flyte's blob store as JPEG File
                handles. Tasks read images directly from here; no live HTTP.
    queries:    text questions with ground-truth page IDs for evaluation.
    """

    page_ids: list[str]
    page_files: list[File]
    queries: list[PageQuery]

    # Pydantic v2 style; the nested `class Config` form is deprecated in v2.
    model_config = {"arbitrary_types_allowed": True}

class RetrievalResult(BaseModel):
    query_id: str
    ranked_page_ids: list[str]  # ordered best → worst

class Metrics(BaseModel):
    recall_at_k: float
    ndcg_at_k: float
    mrr: float
    k: int

class ExperimentResult(BaseModel):
    config: ExperimentConfig
    metrics: Metrics
# {{/docs-fragment data_types}}

class ComparisonReport(BaseModel):
    results: list[ExperimentResult]

    def best_by(self, metric: str = "recall_at_k") -> ExperimentResult:
        return max(self.results, key=lambda r: getattr(r.metrics, metric))

    def summary(self) -> str:
        header = f"{'Experiment':<30} {'Model':<18} {'Recall@K':>10} {'NDCG@K':>8} {'MRR':>7}"
        sep = "─" * len(header)
        rows = [header, sep]
        for r in sorted(self.results, key=lambda x: -x.metrics.recall_at_k):
            rows.append(
                f"{r.config.name:<30} "
                f"{r.config.model.value:<18} "
                f"{r.metrics.recall_at_k:>10.3f} "
                f"{r.metrics.ndcg_at_k:>8.3f} "
                f"{r.metrics.mrr:>7.3f}"
            )
        return "\n".join(rows)

# ─────────────────────────────────────────────────────────────────────────────
# Cached model loaders
# ─────────────────────────────────────────────────────────────────────────────
# These functions are at module level so they are shared across all tasks that
# run on the same warm container (via ReusePolicy). lru_cache(maxsize=1) means
# the model is loaded from disk/HuggingFace exactly once per container process
# and kept in GPU memory for every subsequent task dispatch to that container.

@lru_cache(maxsize=1)
def _colpali_model():
    """Load ColPali-v1.2 into GPU memory and cache the result.

    device_map= is the correct loading pattern for ColPali's PaliGemma
    backbone; it handles weight placement via accelerate. torch.compile is
    skipped — ColPali is GPU-compute-bound and the DynamicBatcher's cross-
    invocation batching is the primary GPU utilisation mechanism.
    """
    import torch
    from colpali_engine.models import ColPali, ColPaliProcessor

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = ColPali.from_pretrained(
        "vidore/colpali-v1.2",
        torch_dtype=torch.bfloat16,
        device_map=device,
    )
    processor = ColPaliProcessor.from_pretrained("vidore/colpali-v1.2")
    return model, processor, device

@lru_cache(maxsize=1)
def _siglip_model():
    """Load SigLIP SO400M into GPU memory, compile it, and cache the result.

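    torch.compile (mode="reduce-overhead") compiles the model's forward pass
    into optimised CUDA kernels; the compilation overhead is paid once per warm
    container lifetime. (ColPali, by contrast, skips torch.compile.) Note that
    attribute access on the compiled wrapper, such as model.text_model in the
    search batcher, reaches the original, uncompiled submodule.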
    """
    import torch
    from transformers import AutoModel, AutoProcessor

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModel.from_pretrained("google/siglip-so400m-patch14-224").to(device)
    if device == "cuda":
        model = torch.compile(model, mode="reduce-overhead")
    processor = AutoProcessor.from_pretrained("google/siglip-so400m-patch14-224")
    return model, processor, device

@lru_cache(maxsize=1)
def _ocr_model():
    """Load the doctr OCR predictor onto GPU and cache it.

    doctr's ocr_predictor bundles a detection model (DBNet) and a
    recognition model (CRNN/SAR) into a single callable. pretrained=True
    downloads both model weights from doctr's model zoo on first use.
    """
    import torch
    from doctr.models import ocr_predictor

    predictor = ocr_predictor(pretrained=True)
    if torch.cuda.is_available():
        predictor = predictor.cuda()
    return predictor

# ─────────────────────────────────────────────────────────────────────────────
# Search batcher singletons
# ─────────────────────────────────────────────────────────────────────────────
# One DynamicBatcher per model, shared across all concurrent search task
# invocations on the same warm container (concurrency=8). Queries from every
# concurrent caller are aggregated into a single GPU batch, maximizing
# throughput compared to each invocation running its own forward pass.
#
# Initialised lazily on the first search call via double-checked locking and
# lives for the container's lifetime. The process_fn runs GPU work via
# asyncio.to_thread so the aggregation loop can continue collecting queries
# from other callers while the GPU processes the current batch.
#
# File is not hashable so alru_cache cannot be used here; module-level state
# with asyncio.Lock is the correct pattern.
#
# Assumption: index_colpali/index_siglip use cache="auto", so the same corpus
# always produces the same index File across all callers on this container. If
# the index file ever changed between calls, the batcher would silently continue
# using the corpus embeddings loaded from the first call.

_colpali_batcher: DynamicBatcher | None = None
_colpali_batcher_lock = asyncio.Lock()
_siglip_batcher: DynamicBatcher | None = None
_siglip_batcher_lock = asyncio.Lock()

async def _get_colpali_search_batcher(index_file: File) -> DynamicBatcher:
    """Return the process-level ColPali search batcher, creating it on first call."""
    global _colpali_batcher
    if _colpali_batcher is not None:
        return _colpali_batcher
    async with _colpali_batcher_lock:
        if _colpali_batcher is not None:
            return _colpali_batcher

        import torch

        data = await _load_npz(index_file)
        corpus_emb = torch.from_numpy(data["embeddings"])  # (n_pages, n_patches, dim)
        index_page_ids: list[str] = list(data["page_ids"])
        model, processor, device = _colpali_model()
        corpus_emb = corpus_emb.to(device, dtype=torch.float32)

        async def colpali_process_fn(batch: list[PageQuery]) -> list[list[str]]:
            def _gpu_work() -> list[list[str]]:
                query_inputs = processor.process_queries([q.text for q in batch])
                query_inputs = {k: v.to(device) for k, v in query_inputs.items()}
                with torch.no_grad():
                    query_embs = model(**query_inputs).float()  # (B, T, D)
                    query_chunk = 8
                    n_pages = corpus_emb.shape[0]
                    all_scores = torch.empty(len(batch), n_pages, device=device)
                    for start in range(0, len(batch), query_chunk):
                        chunk = query_embs[start : start + query_chunk]
                        all_scores[start : start + query_chunk] = (
                            torch.einsum("ctd,pjd->ctpj", chunk, corpus_emb)
                            .max(dim=3).values
                            .sum(dim=1)
                        )
                    sorted_indices = all_scores.argsort(dim=1, descending=True).cpu().tolist()
                return [[index_page_ids[j] for j in ranked] for ranked in sorted_indices]

            # Run GPU work in a thread so the event loop — and the batcher's
            # aggregation loop — remain unblocked while the GPU is busy.
            return await asyncio.to_thread(_gpu_work)

        batcher: DynamicBatcher[PageQuery, list[str]] = DynamicBatcher(
            process_fn=colpali_process_fn,
            target_batch_cost=128,
            max_batch_size=128,
            batch_timeout_s=0.05,
            default_cost=1,
            prefetch_batches=2,
        )
        await batcher.start()
        _colpali_batcher = batcher
    return _colpali_batcher

async def _get_siglip_search_batcher(index_file: File) -> DynamicBatcher:
    """Return the process-level SigLIP search batcher, creating it on first call."""
    global _siglip_batcher
    if _siglip_batcher is not None:
        return _siglip_batcher
    async with _siglip_batcher_lock:
        if _siglip_batcher is not None:
            return _siglip_batcher

        import torch

        data = await _load_npz(index_file)
        corpus_emb = torch.from_numpy(data["embeddings"])  # (n_pages, dim), L2-normalised
        index_page_ids: list[str] = list(data["page_ids"])
        model, processor, device = _siglip_model()
        corpus_emb = corpus_emb.to(device)

        async def siglip_process_fn(batch: list[PageQuery]) -> list[list[str]]:
            def _gpu_work() -> list[list[str]]:
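                # The Hugging Face SigLIP docs recommend padding="max_length",
                # which matches how the text tower was trained.
                text_inputs = processor(
                    text=[q.text for q in batch],
                    return_tensors="pt",
                    padding="max_length",
                    truncation=True,
                ).to(device)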
                with torch.no_grad():
                    text_out = model.text_model(**text_inputs)
                    query_embs = text_out.pooler_output  # (B, dim)
                    query_embs = query_embs / query_embs.norm(dim=-1, keepdim=True)
                    scores_matrix = corpus_emb @ query_embs.T  # (n_pages, B)
                    sorted_indices = scores_matrix.argsort(dim=0, descending=True).T.cpu().tolist()
                return [[index_page_ids[j] for j in ranked] for ranked in sorted_indices]

            return await asyncio.to_thread(_gpu_work)

        batcher = DynamicBatcher(
            process_fn=siglip_process_fn,
            target_batch_cost=128,
            max_batch_size=128,
            batch_timeout_s=0.05,
            default_cost=1,
            prefetch_batches=2,
        )
        await batcher.start()
        _siglip_batcher = batcher
    return _siglip_batcher

# ─────────────────────────────────────────────────────────────────────────────
# Helpers
# ─────────────────────────────────────────────────────────────────────────────

def _batches(items: list, batch_size: int):
    """Yield successive fixed-size batches from a list."""
    for start in range(0, len(items), batch_size):
        yield items[start : start + batch_size]

def _load_image_sync(f: File) -> PILImage.Image:
    """Blocking download + decode. Intended to be called from a thread pool."""
    with f.open_sync("rb") as fh:
        data = fh.read()
    return PILImage.open(BytesIO(data)).convert("RGB")

async def _load_image(f: File) -> PILImage.Image:
    """Download and decode a page image in a thread-pool worker.

    asyncio.to_thread runs _load_image_sync in a real OS thread, so blocking
    network I/O can overlap with GPU-bound forward passes when image loads
    are scheduled ahead of the GPU work that consumes them.
    """
    return await asyncio.to_thread(_load_image_sync, f)

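async def _load_npz(index_file: File) -> np.lib.npyio.NpzFile:
    """Download an index File to a local temp path and open it with np.load."""
    async with index_file.open("rb") as fh:
        payload = bytes(await fh.read())
    with tempfile.NamedTemporaryFile(suffix=".npz", delete=False) as tmp:
        tmp.write(payload)
    # Load only after the with-block has closed (and flushed) the temp file;
    # reading by name while bytes are still buffered can yield a truncated archive.
    return np.load(tmp.name)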

def _dcg(relevances: list[int]) -> float:
    return sum(rel / math.log2(rank + 2) for rank, rel in enumerate(relevances))

# ─────────────────────────────────────────────────────────────────────────────
# Tasks — data loading
# ─────────────────────────────────────────────────────────────────────────────

@driver.task(cache="auto", retries=3)
async def load_vidore_pages(subset: str = "docvqa", max_pages: int = 200) -> PageDataset:
    """
    Load a ViDoRe benchmark subset and store page images in Flyte's blob store.

    Supports two dataset formats:

    Legacy (subsampled) — single 'test' split with one row per (query, page)
    pair; fields: image, query, image_filename. streaming=True reads only the
    rows requested via islice — no full-shard download.
    Datasets: vidore/docvqa_test_subsampled, vidore/infovqa_test_subsampled

    V3 — separate corpus / queries / qrels splits following the BEIR retrieval
    benchmark format. corpus contains page images; queries contains question
    text; qrels maps query IDs to relevant corpus page IDs (many-to-many).
    Datasets: vidore/vidore_v3_finance_en (~2,942 pages, 1,854 queries)

    The first call uploads page images to Flyte's blob store and caches the
    PageDataset; every subsequent call with the same arguments returns the
    cached result instantly. retries=3 guards against transient HuggingFace
    network failures.

    Available subsets: "docvqa", "infovqa", "vidore_v3_finance_en"
    """
    from datasets import load_dataset

    subset_map = {
        "docvqa": "vidore/docvqa_test_subsampled",
        "infovqa": "vidore/infovqa_test_subsampled",
        "vidore_v3_finance_en": "vidore/vidore_v3_finance_en",
    }
    dataset_name = subset_map.get(subset, f"vidore/{subset}_test_subsampled")

    # V3 datasets ship with separate corpus / queries / qrels splits.
    _V3_SUBSETS = {"vidore_v3_finance_en"}

    if subset in _V3_SUBSETS:
        # ── V3 format ─────────────────────────────────────────────────────────
        # corpus / queries / qrels are HuggingFace configs (name=), not splits.
        # corpus uses streaming=True so images are decoded one at a time;
        # loading all ~2,942 rows eagerly would hold gigabytes of PIL images in
        # the driver's RAM simultaneously. qrels and queries are text-only and
        # small enough to load fully into memory.
        corpus_ds = load_dataset(dataset_name, name="corpus", split="test", streaming=True)
        qrels_ds = load_dataset(dataset_name, name="qrels", split="test")
        queries_ds = load_dataset(dataset_name, name="queries", split="test")

        # Normalise field names — V3 follows BEIR convention (hyphenated ids).
        def _col(ds, *candidates):
            cols = set(ds.column_names)
            for c in candidates:
                if c in cols:
                    return c
            raise KeyError(f"None of {candidates} found in columns {cols}")

        corpus_id_col = _col(corpus_ds, "corpus-id", "corpus_id", "id", "_id")
        query_id_col = _col(queries_ds, "query-id", "query_id", "id", "_id")
        query_text_col = _col(queries_ds, "query", "text")
        qrel_qid_col = _col(qrels_ds, "query-id", "query_id")
        qrel_cid_col = _col(qrels_ds, "corpus-id", "corpus_id")

        # Slice corpus to max_pages, upload each image to Flyte blob store.
        page_ids: list[str] = []
        page_files: list[File] = []
        corpus_id_to_page_id: dict[str, str] = {}

        for i, row in enumerate(islice(corpus_ds, max_pages)):
            img = row.get("image")
            if not isinstance(img, PILImage.Image):
                continue
            cid = str(row[corpus_id_col])
            page_id = f"{subset}_{i:04d}"
            with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f:
                tmp_path = f.name
                img.convert("RGB").save(tmp_path, format="JPEG")
            del img  # free PIL memory before upload
            page_file = await File.from_local(tmp_path)
            os.unlink(tmp_path)
            corpus_id_to_page_id[cid] = page_id
            page_ids.append(page_id)
            page_files.append(page_file)

        # Build query_id → relevant page_id from qrels (first match wins).
        # Only keep relevance judgements whose corpus page is in our slice.
        qrel_map: dict[str, str] = {}
        for row in qrels_ds:
            qid = str(row[qrel_qid_col])
            cid = str(row[qrel_cid_col])
            if cid in corpus_id_to_page_id and qid not in qrel_map:
                qrel_map[qid] = corpus_id_to_page_id[cid]

        # Collect queries that have at least one relevant page in our slice.
        queries: list[PageQuery] = []
        for row in queries_ds:
            qid = str(row[query_id_col])
            if qid not in qrel_map:
                continue
            queries.append(
                PageQuery(
                    query_id=qid,
                    text=str(row[query_text_col]),
                    relevant_page_id=qrel_map[qid],
                )
            )

    else:
        # ── Legacy format ─────────────────────────────────────────────────────
        # Single 'test' split with one row per (query, page) pair.
        ds = load_dataset(dataset_name, split="test", streaming=True)

        page_ids = []
        page_files = []
        queries = []
        seen_pages: dict[str, str] = {}  # image_filename → page_id

        for i, row in enumerate(islice(ds, max_pages)):
            img = row.get("image")
            if not isinstance(img, PILImage.Image):
                continue
            filename: str = row.get("image_filename") or f"page_{i}"
            query_text: str = row.get("query", "")
            if not query_text:
                continue

            # Each unique page is uploaded exactly once; multiple queries may
            # share the same page (same image_filename).
            if filename not in seen_pages:
                page_id = f"{subset}_{len(page_ids):04d}"
                with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f:
                    tmp_path = f.name
                    img.convert("RGB").save(tmp_path, format="JPEG")
                del img  # free PIL memory before upload
                page_file = await File.from_local(tmp_path)
                os.unlink(tmp_path)
                seen_pages[filename] = page_id
                page_ids.append(page_id)
                page_files.append(page_file)
            else:
                page_id = seen_pages[filename]

            queries.append(
                PageQuery(
                    query_id=f"q{i:04d}",
                    text=query_text,
                    relevant_page_id=page_id,
                )
            )

    print(f"Loaded {len(page_ids)} unique pages, {len(queries)} queries", flush=True)
    return PageDataset(page_ids=page_ids, page_files=page_files, queries=queries)

# ─────────────────────────────────────────────────────────────────────────────
# Tasks — indexing
# ─────────────────────────────────────────────────────────────────────────────

@colpali_indexer.task(cache="auto", retries=2)
async def index_colpali(page_ids: list[str], page_files: list[File]) -> File:
    """
    Encode every page with ColPali-v1.2 and save the multi-vector index.

    ColPali skips OCR entirely. It feeds the raw page image into PaliGemma
    (a vision-language model) and produces one embedding vector per image
    patch — roughly 1,024 patches per page, each of dimension 128.

    _colpali_model() is an lru_cache'd loader. On a cold container, it
    downloads and loads the model once. On a warm container (kept alive by
    ReusePolicy), it returns the already-loaded model instantly from cache —
    no repeated ~7 GB download.

    The index is stored as a .npz file in Flyte's blob store:
      embeddings — float32, shape (n_pages, n_patches, dim)
      page_ids   — matching page ID strings

    cache="auto" + retries=2: on success the result is cached and reused by
    later calls with the same inputs; transient failures (e.g. HuggingFace
    rate limits) are retried up to twice.
    """
    import torch

    model, processor, device = _colpali_model()

    loop = asyncio.get_running_loop()
    batches = list(_batches(page_files, 4))
    n_batches = len(batches)

    # Submit the first batch to the thread pool before entering the loop so
    # that downloads are already in flight when we first await them.
    prefetch = [loop.run_in_executor(None, _load_image_sync, f) for f in batches[0]]

    all_embeddings: list[np.ndarray] = []
    for batch_idx in range(n_batches):
        images = list(await asyncio.gather(*prefetch))

        # Submit next batch downloads immediately — OS threads run these in
        # parallel with the GPU forward pass below.
        if batch_idx + 1 < n_batches:
            prefetch = [loop.run_in_executor(None, _load_image_sync, f) for f in batches[batch_idx + 1]]

        inputs = processor.process_images(images)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            emb = model(**inputs)  # (batch, n_patches, dim)

        all_embeddings.append(emb.cpu().float().numpy())
        print(f"ColPali: indexed batch {batch_idx + 1}/{n_batches}", flush=True)

    embeddings = np.concatenate(all_embeddings, axis=0)  # (n_pages, n_patches, dim)
    out_path = os.path.join(tempfile.gettempdir(), "colpali_index.npz")
    np.savez(out_path, embeddings=embeddings, page_ids=np.array(page_ids))
    return await File.from_local(out_path)

@siglip_indexer.task(cache="auto", retries=2)
async def index_siglip(page_ids: list[str], page_files: list[File]) -> File:
    """
    Encode every page with SigLIP SO400M and save the single-vector index.

    SigLIP (Google, 2023) is a CLIP-style image-text model trained with a
    pairwise sigmoid loss instead of CLIP's softmax contrastive loss, so no
    normalisation over all image-text pairs in a batch is needed, which makes
    scaling to large batches easier. Produces one global embedding per page.

    _siglip_model() caches the model across warm container reuses.

    The index is stored as a .npz file:
      embeddings — float32, shape (n_pages, dim), L2-normalised
      page_ids   — matching page ID strings
    """
    import torch

    model, processor, device = _siglip_model()

    loop = asyncio.get_running_loop()
    batches = list(_batches(page_files, 8))
    n_batches = len(batches)

    # Submit the first batch to the thread pool before entering the loop so
    # that downloads are already in flight when we first await them.
    prefetch = [loop.run_in_executor(None, _load_image_sync, f) for f in batches[0]]

    all_embeddings: list[np.ndarray] = []
    for batch_idx in range(n_batches):
        images = list(await asyncio.gather(*prefetch))

        # Submit next batch downloads immediately — OS threads run these in
        # parallel with the GPU forward pass below.
        if batch_idx + 1 < n_batches:
            prefetch = [loop.run_in_executor(None, _load_image_sync, f) for f in batches[batch_idx + 1]]

        inputs = processor(images=images, return_tensors="pt", padding=True).to(device)

        with torch.no_grad():
            outputs = model.vision_model(**inputs)
            emb = outputs.pooler_output  # (batch, dim)
            emb = emb / emb.norm(dim=-1, keepdim=True)  # L2 normalise

        all_embeddings.append(emb.cpu().float().numpy())
        print(f"SigLIP: indexed batch {batch_idx + 1}/{n_batches}", flush=True)

    embeddings = np.concatenate(all_embeddings, axis=0)  # (n_pages, dim)
    out_path = os.path.join(tempfile.gettempdir(), "siglip_index.npz")
    np.savez(out_path, embeddings=embeddings, page_ids=np.array(page_ids))
    return await File.from_local(out_path)

@ocr_engine.task(cache="auto")
async def extract_page_texts(page_files: list[File]) -> list[str]:
    """
    OCR every page with doctr on GPU to produce a text-only baseline.

    doctr bundles DBNet (detection) + CRNN/SAR (recognition) into a single
    callable predictor. Each batch of ocr_batch_size pages is downloaded in
    parallel just before it is OCR'd, and asyncio.to_thread keeps the event
    loop unblocked during GPU inference.

    Result structure: result.pages[i].blocks[j].lines[k].words[l].value

    Cached: the same corpus is OCR'd at most once across all experiments
    that use the OCR+BM25 backend.
    """
    import gc

    predictor = _ocr_model()

    # Process in batches: download each batch just-in-time so only
    # ocr_batch_size images are in memory at once instead of the whole corpus.
    ocr_batch_size = 8
    total = len(page_files)
    texts: list[str] = []
    for start in range(0, total, ocr_batch_size):
        batch_files = page_files[start : start + ocr_batch_size]
        batch_images = list(
            await asyncio.gather(*[asyncio.to_thread(_load_image_sync, f) for f in batch_files])
        )
        batch_np = [np.array(img) for img in batch_images]
        del batch_images
        result = await asyncio.to_thread(predictor, batch_np)
        del batch_np
        for page_output in result.pages:
            texts.append(
                "\n".join(
                    " ".join(word.value for word in line.words)
                    for block in page_output.blocks
                    for line in block.lines
                )
            )
        del result
        gc.collect()
        print(f"OCR: processed {min(start + ocr_batch_size, total)}/{total} pages", flush=True)

    return texts

# ─────────────────────────────────────────────────────────────────────────────
# Tasks — search
# ─────────────────────────────────────────────────────────────────────────────

# {{docs-fragment search_colpali}}
@colpali_indexer.task
async def search_colpali(
    index_file: File,
    queries: list[PageQuery],
    top_k: int,
) -> list[RetrievalResult]:
    """
    Retrieve pages using ColPali MaxSim late interaction via DynamicBatcher.

    MaxSim score for page p given query q:
        score(q, p) = Σ_{t ∈ query tokens} max_{j ∈ page patches} (q_t · p_j)

    Each query is submitted to the process-level DynamicBatcher, which
    aggregates queries from all concurrent search_colpali invocations on the
    same warm container (concurrency=8) into a single GPU batch. This keeps
    the GPU saturated rather than running one small batch per caller.

    The batcher's process_fn runs GPU work in asyncio.to_thread, so the
    aggregation loop stays live while the GPU encodes and scores.
    """
    batcher = await _get_colpali_search_batcher(index_file)
    futures = await batcher.submit_batch(queries)
    all_ranked: list[list[str]] = list(await asyncio.gather(*futures))

    return [
        RetrievalResult(query_id=q.query_id, ranked_page_ids=ranked[:top_k])
        for q, ranked in zip(queries, all_ranked)
    ]
# {{/docs-fragment search_colpali}}

@siglip_indexer.task
async def search_siglip(
    index_file: File,
    queries: list[PageQuery],
    top_k: int,
) -> list[RetrievalResult]:
    """
    Retrieve pages using SigLIP cosine similarity via DynamicBatcher.

    Each query is submitted to the process-level DynamicBatcher, which
    aggregates queries from all concurrent search_siglip invocations on the
    same warm container (concurrency=8) into a single GPU batch.

    SigLIP's single-vector embeddings make full vectorisation safe —
    the scores matrix (n_pages x n_queries) is small enough to materialise
    in one GPU call regardless of batch size.
    """
    batcher = await _get_siglip_search_batcher(index_file)
    futures = await batcher.submit_batch(queries)
    all_ranked: list[list[str]] = list(await asyncio.gather(*futures))

    return [
        RetrievalResult(query_id=q.query_id, ranked_page_ids=ranked[:top_k])
        for q, ranked in zip(queries, all_ranked)
    ]

@driver.task
async def search_bm25(
    page_texts: list[str],
    page_ids: list[str],
    queries: list[PageQuery],
    top_k: int,
) -> list[RetrievalResult]:
    """
    Retrieve pages using BM25 over OCR'd text.

    The standard keyword-based baseline. The BM25 search itself needs no GPU
    (only the upstream doctr OCR step does); strong on text-dense pages, weak
    on visual content that OCR cannot read.
    """
    tokenized = [text.lower().split() for text in page_texts]
    bm25 = BM25Okapi(tokenized)

    results: list[RetrievalResult] = []
    for q in queries:
        scores = bm25.get_scores(q.text.lower().split())
        ranked = sorted(range(len(page_ids)), key=lambda i: -scores[i])[:top_k]
        results.append(
            RetrievalResult(
                query_id=q.query_id,
                ranked_page_ids=[page_ids[i] for i in ranked],
            )
        )
    return results

# ─────────────────────────────────────────────────────────────────────────────
# Tasks — evaluation
# ─────────────────────────────────────────────────────────────────────────────

@driver.task
async def evaluate(
    results: list[RetrievalResult],
    ground_truth: list[PageQuery],
    k: int,
) -> Metrics:
    """
    Compute Recall@K, NDCG@K, and MRR for a single retrieval model.

    Recall@K: was the correct page in the top-K results?
    NDCG@K:   normalised discounted cumulative gain; rewards earlier hits.
    MRR:      mean reciprocal rank of the first correct result. The search
              tasks already truncate each ranked list to top_k, so this is
              effectively MRR@K.

    All three are averaged over all queries. Higher is better.
    """
    gt_map = {q.query_id: q.relevant_page_id for q in ground_truth}
    recall_vals, ndcg_vals, mrr_vals = [], [], []

    for r in results:
        relevant = gt_map.get(r.query_id, "")
        top = r.ranked_page_ids[:k]

        recall_vals.append(1.0 if relevant in top else 0.0)

        rels = [1 if pid == relevant else 0 for pid in top]
        idcg = _dcg([1])  # ideal: correct page at rank 1
        ndcg_vals.append(_dcg(rels) / idcg if idcg > 0 else 0.0)

        rr = 0.0
        for rank, pid in enumerate(r.ranked_page_ids, start=1):
            if pid == relevant:
                rr = 1.0 / rank
                break
        mrr_vals.append(rr)

    return Metrics(
        recall_at_k=float(np.mean(recall_vals)),
        ndcg_at_k=float(np.mean(ndcg_vals)),
        mrr=float(np.mean(mrr_vals)),
        k=k,
    )

# ─────────────────────────────────────────────────────────────────────────────
# Tasks — report
# ─────────────────────────────────────────────────────────────────────────────

@driver.task(report=True)
async def generate_report(report: ComparisonReport) -> None:
    """
    Emit an interactive HTML report visible in the Flyte UI.

    report=True marks this task as a reporting task. Flyte renders the HTML
    written with flyte.report.replace.aio() (followed by
    flyte.report.flush.aio()) directly in the execution detail page, so no
    separate dashboard or export step is required.

    The report contains:
      - Summary cards: experiment count, best model, best Recall@K.
      - Grouped bar chart: Recall@K, NDCG@K, MRR side-by-side per experiment.
      - Ranked results table with all three metrics.
    """
    sorted_results = sorted(report.results, key=lambda r: -r.metrics.recall_at_k)
    best = sorted_results[0]

    labels = [r.config.name for r in sorted_results]
    recall_vals = [r.metrics.recall_at_k for r in sorted_results]
    ndcg_vals = [r.metrics.ndcg_at_k for r in sorted_results]
    mrr_vals = [r.metrics.mrr for r in sorted_results]

    table_rows = "".join(
        f"""
        <tr>
          <td>{r.config.name}</td>
          <td>{r.config.model.value}</td>
          <td>{r.metrics.recall_at_k:.3f}</td>
          <td>{r.metrics.ndcg_at_k:.3f}</td>
          <td>{r.metrics.mrr:.3f}</td>
          <td>{r.metrics.k}</td>
        </tr>"""
        for r in sorted_results
    )

    html = f"""<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <title>Visual Document Retrieval — Results</title>
  <script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
  <style>
    * {{ box-sizing: border-box; margin: 0; padding: 0; }}
    body {{
      font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;
      background: #f0f2f5; color: #222; padding: 24px;
    }}
    h1 {{ font-size: 1.6em; margin-bottom: 4px; }}
    .subtitle {{ color: #666; margin-bottom: 24px; font-size: 0.95em; }}
    .cards {{
      display: flex; gap: 16px; flex-wrap: wrap; margin-bottom: 28px;
    }}
    .card {{
      background: #fff; border-radius: 10px; padding: 18px 24px;
      box-shadow: 0 1px 4px rgba(0,0,0,.08); min-width: 160px;
    }}
    .card-value {{ font-size: 1.9em; font-weight: 700; color: #4f46e5; }}
    .card-label {{ font-size: 0.8em; color: #888; text-transform: uppercase;
                   letter-spacing: .04em; margin-top: 2px; }}
    .chart-box {{
      background: #fff; border-radius: 10px; padding: 24px;
      box-shadow: 0 1px 4px rgba(0,0,0,.08); margin-bottom: 28px;
    }}
    .chart-box h2 {{ font-size: 1em; margin-bottom: 16px; color: #444; }}
    table {{ width: 100%; border-collapse: collapse; font-size: 0.9em; }}
    th {{
      background: #4f46e5; color: #fff; padding: 10px 14px;
      text-align: left; font-weight: 600;
    }}
    td {{ padding: 9px 14px; border-bottom: 1px solid #eee; }}
    tr:hover td {{ background: #f8f8ff; }}
    tr:first-child td {{ font-weight: 600; }}
  </style>
</head>
<body>
  <h1>Visual Document Retrieval — Experiment Comparison</h1>
  <p class="subtitle">ViDoRe benchmark &middot; {len(report.results)} experiment(s)</p>

  <div class="cards">
    <div class="card">
      <div class="card-value">{len(report.results)}</div>
      <div class="card-label">Experiments</div>
    </div>
    <div class="card">
      <div class="card-value">{best.config.name}</div>
      <div class="card-label">Best by Recall@K</div>
    </div>
    <div class="card">
      <div class="card-value">{best.metrics.recall_at_k:.3f}</div>
      <div class="card-label">Best Recall@{best.metrics.k}</div>
    </div>
    <div class="card">
      <div class="card-value">{best.metrics.ndcg_at_k:.3f}</div>
      <div class="card-label">Best NDCG@{best.metrics.k}</div>
    </div>
    <div class="card">
      <div class="card-value">{best.metrics.mrr:.3f}</div>
      <div class="card-label">Best MRR</div>
    </div>
  </div>

  <div class="chart-box">
    <h2>Metrics by Experiment</h2>
    <canvas id="metricsChart" height="100"></canvas>
  </div>

  <div class="chart-box">
    <h2>Ranked Results</h2>
    <table>
      <thead>
        <tr>
          <th>Experiment</th><th>Model</th>
          <th>Recall@K</th><th>NDCG@K</th><th>MRR</th><th>K</th>
        </tr>
      </thead>
      <tbody>{table_rows}</tbody>
    </table>
  </div>

  <script>
    new Chart(document.getElementById('metricsChart'), {{
      type: 'bar',
      data: {{
        labels: {json.dumps(labels)},
        datasets: [
          {{
            label: 'Recall@K',
            data: {json.dumps(recall_vals)},
            backgroundColor: 'rgba(79,70,229,0.75)',
            borderRadius: 4
          }},
          {{
            label: 'NDCG@K',
            data: {json.dumps(ndcg_vals)},
            backgroundColor: 'rgba(16,185,129,0.75)',
            borderRadius: 4
          }},
          {{
            label: 'MRR',
            data: {json.dumps(mrr_vals)},
            backgroundColor: 'rgba(245,158,11,0.75)',
            borderRadius: 4
          }}
        ]
      }},
      options: {{
        responsive: true,
        plugins: {{ legend: {{ position: 'top' }} }},
        scales: {{
          y: {{ beginAtZero: true, max: 1.0,
               title: {{ display: true, text: 'Score' }} }}
        }}
      }}
    }});
  </script>
</body>
</html>"""

    await flyte.report.replace.aio(html)
    await flyte.report.flush.aio()

# ─────────────────────────────────────────────────────────────────────────────
# Experiment orchestration
# ─────────────────────────────────────────────────────────────────────────────

# {{docs-fragment run_experiment}}
@driver.task
async def run_experiment(config: ExperimentConfig, dataset: PageDataset) -> ExperimentResult:
    """
    End-to-end retrieval pipeline for a single ExperimentConfig.

    Flyte v2's dynamic execution means this driver task can call GPU tasks
    (index_colpali, search_colpali) based on the runtime value of config.model
    — no static DAG wiring required. The if/elif is plain Python; Flyte
    schedules the selected sub-tasks on the appropriate environment.

    Caching: two experiments that share the same model and corpus (e.g. ColPali
    at top_k=5 and top_k=10) will hit the same cached index. GPU work is paid
    at most once per (model, corpus) pair across all experiments.

    Search queries are sharded into chunks of SEARCH_SHARD_SIZE and dispatched
    as concurrent task invocations. All shards land on the single warm container
    (replicas=1) and feed the same DynamicBatcher simultaneously, keeping the
    GPU saturated throughout search rather than processing one large sequential
    batch from a single caller.

    flyte.group wraps each experiment in a named span in the Flyte UI, making
    it easy to compare latencies and drill into individual runs.
    """
    SEARCH_SHARD_SIZE = 256

    with flyte.group(config.name):
        if config.model == RetrievalModel.COLPALI:
            index_file = await index_colpali(dataset.page_ids, dataset.page_files)
            shards = list(_batches(dataset.queries, SEARCH_SHARD_SIZE))
            shard_results = await asyncio.gather(
                *[search_colpali(index_file, shard, config.top_k) for shard in shards]
            )
            results = [r for shard in shard_results for r in shard]

        elif config.model == RetrievalModel.SIGLIP:
            index_file = await index_siglip(dataset.page_ids, dataset.page_files)
            shards = list(_batches(dataset.queries, SEARCH_SHARD_SIZE))
            shard_results = await asyncio.gather(
                *[search_siglip(index_file, shard, config.top_k) for shard in shards]
            )
            results = [r for shard in shard_results for r in shard]

        else:  # RetrievalModel.OCR_BM25
            page_texts = await extract_page_texts(dataset.page_files)
            results = await search_bm25(page_texts, dataset.page_ids, dataset.queries, config.top_k)

        metrics = await evaluate(results, dataset.queries, config.top_k)

    return ExperimentResult(config=config, metrics=metrics)
# {{/docs-fragment run_experiment}}

# {{docs-fragment compare_experiments}}
@driver.task
async def compare_experiments(
    configs: list[ExperimentConfig],
    subset: str = "docvqa",
    max_pages: int = 200,
) -> ComparisonReport:
    """
    Fan out over all experiment configs and return a ranked comparison table.

    The dataset is loaded once and shared across all experiments. Each config
    runs as a concurrent Flyte task via asyncio.gather. Experiments that share
    a model reuse the cached index — you only pay GPU time for new work.

    On completion, generate_report emits an interactive Chart.js HTML report
    visible directly in the Flyte execution detail page.

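    The parameter defaults (subset="docvqa", max_pages=200) keep runs small.
    The entry point at the bottom of this file instead passes
    subset="vidore_v3_finance_en" (~2,942 corpus pages, 1,854 queries) with
    max_pages=2000 to exercise the GPU pipeline at scale.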
    """
    dataset = await load_vidore_pages(subset=subset, max_pages=max_pages)

    # All experiments launch concurrently. Shared cached outputs (same model,
    # same corpus) are served from cache rather than recomputed.
    experiment_coros = [run_experiment(config=cfg, dataset=dataset) for cfg in configs]
    results: list[ExperimentResult] = list(await asyncio.gather(*experiment_coros))

    report = ComparisonReport(results=results)
    print(report.summary())
    best = report.best_by("recall_at_k")
    print(f"\nBest by Recall@{best.metrics.k}: {best.config.name}")

    # Emit the interactive HTML report in the Flyte UI.
    await generate_report(report)

    return report
# {{/docs-fragment compare_experiments}}

# ─────────────────────────────────────────────────────────────────────────────
# Entry point
# ─────────────────────────────────────────────────────────────────────────────

if __name__ == "__main__":
    flyte.init_from_config()

    # Define the experiment grid. Each ExperimentConfig is one point in the
    # design space. Adding a new model or varying top_k is one line here —
    # no task code changes required.
    #
    # ColPali appears twice with different top_k values. The cache ensures
    # index_colpali runs only once and both experiments share that result.
    # {{docs-fragment grid}}
    configs = [
        ExperimentConfig(name="colpali-top5", model=RetrievalModel.COLPALI, top_k=5),
        ExperimentConfig(name="colpali-top10", model=RetrievalModel.COLPALI, top_k=10),
        ExperimentConfig(name="siglip-top5", model=RetrievalModel.SIGLIP, top_k=5),
        ExperimentConfig(name="ocr-bm25-top5", model=RetrievalModel.OCR_BM25, top_k=5),
    ]
    # {{/docs-fragment grid}}

    run = flyte.with_runcontext().run(
        compare_experiments,
        configs=configs,
        subset="vidore_v3_finance_en",
        max_pages=2000,
    )
    print(f"Run URL: {run.url}")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/multimodal-retrieval-evaluation/retrieval_eval.py*
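The MaxSim formula in the `search_colpali` docstring can be checked on toy vectors. The sketch below is plain Python rather than the batched GPU scoring the tutorial uses; `maxsim`, `mean_pooled_cosine`, and `rank_pages` are hypothetical helper names, and the 2-D vectors are made up (real ColPali embeddings have dimension 128 and roughly 1,024 patches per page). Mean pooling is only a crude stand-in for a learned single-vector embedding such as SigLIP's.

```python
import math


def dot(a: list[float], b: list[float]) -> float:
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def maxsim(query_tokens: list[list[float]], page_patches: list[list[float]]) -> float:
    """Late interaction: each query token keeps only its best-matching page
    patch, and those per-token maxima are summed."""
    return sum(max(dot(q, p) for p in page_patches) for q in query_tokens)


def mean_pooled_cosine(query_tokens: list[list[float]], page_patches: list[list[float]]) -> float:
    """Single-vector stand-in: average each side to one vector, then take the cosine."""
    def mean(vecs: list[list[float]]) -> list[float]:
        return [sum(col) / len(vecs) for col in zip(*vecs)]

    q, p = mean(query_tokens), mean(page_patches)
    return dot(q, p) / (math.sqrt(dot(q, q)) * math.sqrt(dot(p, p)))


def rank_pages(query_tokens: list[list[float]], pages: dict[str, list[list[float]]]) -> list[str]:
    """Page IDs sorted by descending MaxSim score."""
    return sorted(pages, key=lambda pid: maxsim(query_tokens, pages[pid]), reverse=True)


# Toy 2-D vectors: a 2-token query and two pages with 3 patches each.
query = [[1.0, 0.0], [0.0, 1.0]]
pages = {
    "page_a": [[0.9, 0.1], [0.2, 0.8], [0.0, 0.0]],  # a distinct patch per query token
    "page_b": [[0.5, 0.5], [0.5, 0.5], [0.4, 0.4]],  # only generic patches
}
print(rank_pages(query, pages))  # → ['page_a', 'page_b']
for pid, patches in pages.items():
    print(pid, round(maxsim(query, patches), 3), round(mean_pooled_cosine(query, patches), 3))
```

MaxSim ranks `page_a` first because each query token finds its own matching patch, while mean pooling blurs `page_a`'s distinct patches and scores `page_b` higher. That mirrors the module docstring's point that a single vector cannot localise fine-grained regions.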
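The per-query arithmetic in `evaluate` is easy to check by hand. With a single relevant page, the ideal DCG is 1/log2(2) = 1, so NDCG@K reduces to 1/log2(rank + 1) when the hit lands at 1-based `rank` within the top K, and to 0 otherwise. The standalone sketch below mirrors `_dcg` and the loop body of `evaluate`; `score_query` is a hypothetical name and the page IDs are made up.

```python
import math


def dcg(relevances: list[int]) -> float:
    """Discounted cumulative gain; 0-based index i is discounted by log2(i + 2)."""
    return sum(rel / math.log2(i + 2) for i, rel in enumerate(relevances))


def score_query(ranked: list[str], relevant: str, k: int) -> tuple[float, float, float]:
    """Return (recall@k, ndcg@k, reciprocal rank) for one query with one relevant page."""
    top = ranked[:k]
    recall = 1.0 if relevant in top else 0.0
    ndcg = dcg([1 if pid == relevant else 0 for pid in top]) / dcg([1])
    # Reciprocal rank of the first correct hit, or 0.0 if it never appears.
    rr = next((1.0 / r for r, pid in enumerate(ranked, start=1) if pid == relevant), 0.0)
    return recall, ndcg, rr


# Correct page at rank 3 of a top-5 list.
print(score_query(["p1", "p2", "p7", "p4", "p5"], "p7", k=5))  # → (1.0, 0.5, 0.3333333333333333)
# Correct page missing from the top-5 list.
print(score_query(["p1", "p2", "p3", "p4", "p5"], "p7", k=5))  # → (0.0, 0.0, 0.0)
```

Because the search tasks already truncate each ranked list to `top_k`, the reciprocal rank computed this way is effectively MRR@K; a hit at rank K+1 or lower counts as a miss.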
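`index_colpali` and `index_siglip` overlap downloads with GPU work by submitting the next batch's `_load_image_sync` calls to the default thread pool before running the current forward pass. The same double-buffering pattern is sketched below with stand-ins: `fetch` sleeps to mimic blocking network I/O, and `compute` sleeps on the event-loop thread to mimic a synchronous GPU call. Both functions, and `pipeline`, are hypothetical placeholders, not part of the tutorial.

```python
import asyncio
import time


def fetch(item: int) -> int:
    """Stand-in for a blocking download + decode (runs in a worker thread)."""
    time.sleep(0.05)
    return item


def compute(batch: list[int]) -> list[int]:
    """Stand-in for a synchronous GPU forward pass on the event-loop thread."""
    time.sleep(0.05)
    return [x * 10 for x in batch]


async def pipeline(items: list[int], batch_size: int) -> list[int]:
    loop = asyncio.get_running_loop()
    batches = [items[i : i + batch_size] for i in range(0, len(items), batch_size)]
    if not batches:
        return []
    # run_in_executor submits work immediately and returns futures, so these
    # downloads start now (asyncio.to_thread would return a coroutine that
    # does nothing until it is awaited or wrapped in a task).
    prefetch = [loop.run_in_executor(None, fetch, x) for x in batches[0]]
    out: list[int] = []
    for idx in range(len(batches)):
        fetched = list(await asyncio.gather(*prefetch))
        if idx + 1 < len(batches):
            # Queue the next batch before blocking on compute(), so worker
            # threads download it while compute() occupies this thread.
            prefetch = [loop.run_in_executor(None, fetch, x) for x in batches[idx + 1]]
        out.extend(compute(fetched))
    return out


start = time.perf_counter()
result = asyncio.run(pipeline(list(range(8)), batch_size=4))
print(result, f"{time.perf_counter() - start:.2f}s")
```

With two batches, a strictly sequential loop would pay fetch + compute twice; here the second fetch runs during the first compute, so wall time should drop by roughly one fetch.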
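`search_colpali` and `search_siglip` rely on a `DynamicBatcher` imported from `extras`, which is not shown in this listing. The underlying idea is that many concurrent callers each submit a few items, and one background loop groups whatever has arrived into a single batch for the model. The sketch below illustrates only that idea under simple assumptions (a synchronous `process_fn` returning one result per item, in order); `MiniBatcher`, `submit`, `max_batch`, and `max_wait` are hypothetical names, not the `extras.DynamicBatcher` API.

```python
import asyncio
from typing import Any, Callable


class MiniBatcher:
    """Group items from concurrent callers into batches for one process_fn call."""

    def __init__(self, process_fn: Callable[[list[Any]], list[Any]],
                 max_batch: int = 32, max_wait: float = 0.01) -> None:
        self._process_fn = process_fn  # synchronous, e.g. a GPU forward pass
        self._max_batch = max_batch
        self._max_wait = max_wait
        self._queue: asyncio.Queue = asyncio.Queue()
        self._loop_task: asyncio.Task | None = None
        self.batch_sizes: list[int] = []  # recorded for inspection only

    def start(self) -> None:
        self._loop_task = asyncio.create_task(self._run())

    async def submit(self, item: Any) -> Any:
        """Enqueue one item and wait for its individual result."""
        fut = asyncio.get_running_loop().create_future()
        await self._queue.put((item, fut))
        return await fut

    async def _run(self) -> None:
        loop = asyncio.get_running_loop()
        while True:
            batch = [await self._queue.get()]  # block until at least one item
            deadline = loop.time() + self._max_wait
            # Keep collecting until the batch is full or max_wait elapses.
            while len(batch) < self._max_batch:
                remaining = deadline - loop.time()
                if remaining <= 0:
                    break
                try:
                    batch.append(await asyncio.wait_for(self._queue.get(), remaining))
                except asyncio.TimeoutError:
                    break
            self.batch_sizes.append(len(batch))
            try:
                # Run blocking work off the event loop so new submissions
                # keep queueing while this batch is processed.
                results = await asyncio.to_thread(self._process_fn, [i for i, _ in batch])
            except Exception as exc:
                for _, fut in batch:
                    if not fut.done():
                        fut.set_exception(exc)
                continue
            for (_, fut), res in zip(batch, results):
                if not fut.done():
                    fut.set_result(res)


async def main() -> tuple[list[int], list[int]]:
    batcher = MiniBatcher(lambda xs: [x * x for x in xs], max_batch=4)
    batcher.start()
    # Ten concurrent callers, each submitting one item.
    results = await asyncio.gather(*(batcher.submit(i) for i in range(10)))
    return list(results), batcher.batch_sizes


results, sizes = asyncio.run(main())
print(results)  # squares of 0..9, in submission order
print(sizes)    # how the ten items were grouped, at most max_batch per call
```

The design trade-off is `max_wait`: a longer wait yields fuller batches (better GPU utilisation) at the cost of added latency for the first caller in each batch.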

## Run the evaluation

This example needs no secrets: datasets and model weights are pulled from public Hugging Face repositories. It does require GPUs, so run it remotely.

The experiment grid is defined in the entry point; adding a model or varying `top_k` is a one-line change:

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "colpali-engine>=0.3.1",
#     "transformers>=4.41",
#     "sentencepiece>=0.2",
#     "torch>=2.0",
#     "pillow>=10",
#     "datasets>=2.18",
#     "rank-bm25>=0.2",
#     "numpy>=1.26",
#     "python-doctr[torch]>=0.8",
#     "pydantic>=2.0",
#     "flyte>=2.0.0",
# ]
# ///
"""
Multimodal Retrieval Evaluation Pipeline

This tutorial is an experiment framework for benchmarking visual document
retrieval approaches on the ViDoRe benchmark. Each experiment is defined by
an ExperimentConfig; the pipeline fans them out as concurrent Flyte tasks and
returns a ranked comparison table with an interactive HTML report.

The corpus is a set of PDF page images; queries are plain-text questions. Each
retrieval method must find the page that answers each question. No page text
is supplied with the corpus: every method starts from the raw image, and the
OCR baseline has to extract its own text.

  ColPali-v1.2  — patch-level multi-vector embeddings from a VLM (PaliGemma).
                  No OCR. The model produces one vector per image patch
                  (~1024 per page). MaxSim late-interaction scoring finds the
                  best matching patch for each query token.

  SigLIP-SO400M — single global embedding per page from Google's 2023 CLIP
                  successor. One matrix multiply per query; fast and effective
                  but a single vector cannot localise fine-grained regions.

  OCR + BM25    — text-only baseline. doctr (GPU OCR) extracts text in
                  batches, BM25 matches keywords. Strong on text-dense pages;
                  fails on charts, tables, and figures where content is visual.

"""

import asyncio
import enum
import json
import math
import os
import tempfile
from functools import lru_cache
from io import BytesIO
from itertools import islice

import numpy as np
from PIL import Image as PILImage
from pydantic import BaseModel
from rank_bm25 import BM25Okapi

from extras import DynamicBatcher

import flyte
import flyte.report
from flyte.io import File

# ─────────────────────────────────────────────────────────────────────────────
# Environments
# ─────────────────────────────────────────────────────────────────────────────

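# One Docker image for all tasks. The PEP 723 header defines Python deps.
# ca-certificates is required for HTTPS calls to HuggingFace and blob stores;
# libxcb1, libgl1 and libglib2.0-0 are system libraries that OpenCV (pulled in
# by doctr) typically needs at import time.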
# {{docs-fragment image}}
image = (
    flyte.Image.from_uv_script(__file__, name="vidore-eval-v2")
    .with_apt_packages("ca-certificates", "libxcb1", "libgl1", "libglib2.0-0")
    # unionai-reuse installs the unionai-actor-bridge binary required by ReusePolicy.
    # Without it every reusable container exits with StartError (exit code 128).
    .with_pip_packages("unionai-reuse>=0.1.11")
)
# {{/docs-fragment image}}

# GPU environment for ColPali image encoding and search.
#
# ReusePolicy keeps a warm GPU container alive between task calls.
# Without it, every task invocation cold-starts a new container and downloads
# ColPali-v1.2 (~7 GB) from scratch. With it, the container, along with the
# model weights already loaded into VRAM, is reused for the next task dispatch.
#
#   replicas=1      single warm container — all concurrent shard calls land
#                   here so they share one DynamicBatcher process
#   concurrency=8   up to 8 query-shard tasks run simultaneously on the
#                   container, all feeding the same DynamicBatcher queue
#   idle_ttl=120    keep alive 2 min after the last task finishes
#   scaledown_ttl=60 scale to zero after 1 min of complete inactivity
# {{docs-fragment envs}}
colpali_indexer = flyte.TaskEnvironment(
    name="vidore-colpali-indexer",
    image=image,
    resources=flyte.Resources(cpu=4, memory="16Gi", gpu="A10G:1"),
    reusable=flyte.ReusePolicy(
        replicas=1,
        concurrency=8,
        idle_ttl=120,
        scaledown_ttl=60,
    ),
)

# GPU environment for SigLIP image encoding and search.
#
# Separate from the ColPali environment so each model's warm containers
# are managed independently — ColPali and SigLIP experiments can scale
# without contending for the same pool of reusable containers.
siglip_indexer = flyte.TaskEnvironment(
    name="vidore-siglip-indexer",
    image=image,
    resources=flyte.Resources(cpu=4, memory="8Gi", gpu=1),
    reusable=flyte.ReusePolicy(
        replicas=1,
        concurrency=8,
        idle_ttl=120,
        scaledown_ttl=60,
    ),
)

# GPU environment for doctr OCR. doctr runs DBNet (detection) + CRNN (recognition)
# in batches on GPU — much faster than CPU Tesseract.
# No ReusePolicy needed: the result is cached, so this task runs at most once per corpus.
ocr_engine = flyte.TaskEnvironment(
    name="vidore-ocr-engine",
    image=image,
    resources=flyte.Resources(cpu=4, memory="20Gi", gpu=1),
)

# Driver: orchestration, BM25 search, evaluation, and reporting.
# depends_on declares the GPU environments as dependencies of the driver, so
# deploying the driver also deploys them and its tasks can call their tasks.
driver = flyte.TaskEnvironment(
    name="vidore-driver",
    image=image,
    resources=flyte.Resources(cpu=2, memory="12Gi"),
    depends_on=[colpali_indexer, siglip_indexer, ocr_engine],
)
# {{/docs-fragment envs}}

# ─────────────────────────────────────────────────────────────────────────────
# Configuration types
# ─────────────────────────────────────────────────────────────────────────────

# {{docs-fragment config_types}}
class RetrievalModel(str, enum.Enum):
    """Retrieval backend to evaluate."""

    COLPALI = "colpali-v1.2"  # multi-vector patch embeddings, MaxSim
    SIGLIP = "siglip-so400m"  # single-vector global embedding, cosine sim
    OCR_BM25 = "ocr+bm25"  # text extracted by doctr OCR, ranked by BM25

class ExperimentConfig(BaseModel):
    """
    All knobs for one retrieval experiment. Passed as a typed Flyte input.

    Because ExperimentConfig is a Pydantic model, Flyte serialises it as a
    typed input, and ExperimentResult stores it next to its metrics, so you
    can always reconstruct which config produced which metric without
    maintaining a separate log.
    """

    name: str  # human-readable label shown in the comparison table
    model: RetrievalModel
    top_k: int = 5  # number of pages to retrieve per query
# {{/docs-fragment config_types}}

# ─────────────────────────────────────────────────────────────────────────────
# Data types
# ─────────────────────────────────────────────────────────────────────────────

# {{docs-fragment data_types}}
class PageQuery(BaseModel):
    """One retrieval query with its ground-truth page."""

    query_id: str
    text: str  # e.g. "What was revenue growth in Q3?"
    relevant_page_id: str  # one correct page per query

class PageDataset(BaseModel):
    """
    A corpus of document page images paired with text queries.

    page_ids:   unique page identifiers of the form "<subset>_NNNN" (e.g. "docvqa_0007").
    page_files: the same pages stored in Flyte's blob store as JPEG File
                handles. Tasks read images directly from here; no live HTTP.
    queries:    text questions with ground-truth page IDs for evaluation.
    """

    page_ids: list[str]
    page_files: list[File]
    queries: list[PageQuery]

    class Config:
        arbitrary_types_allowed = True

class RetrievalResult(BaseModel):
    query_id: str
    ranked_page_ids: list[str]  # ordered best → worst

class Metrics(BaseModel):
    recall_at_k: float
    ndcg_at_k: float
    mrr: float
    k: int

class ExperimentResult(BaseModel):
    config: ExperimentConfig
    metrics: Metrics
# {{/docs-fragment data_types}}

class ComparisonReport(BaseModel):
    results: list[ExperimentResult]

    def best_by(self, metric: str = "recall_at_k") -> ExperimentResult:
        return max(self.results, key=lambda r: getattr(r.metrics, metric))

    def summary(self) -> str:
        header = f"{'Experiment':<30} {'Model':<18} {'Recall@K':>10} {'NDCG@K':>8} {'MRR':>7}"
        sep = "─" * len(header)
        rows = [header, sep]
        for r in sorted(self.results, key=lambda x: -x.metrics.recall_at_k):
            rows.append(
                f"{r.config.name:<30} "
                f"{r.config.model.value:<18} "
                f"{r.metrics.recall_at_k:>10.3f} "
                f"{r.metrics.ndcg_at_k:>8.3f} "
                f"{r.metrics.mrr:>7.3f}"
            )
        return "\n".join(rows)

# ─────────────────────────────────────────────────────────────────────────────
# Cached model loaders
# ─────────────────────────────────────────────────────────────────────────────
# These functions are at module level so they are shared across all tasks that
# run on the same warm container (via ReusePolicy). lru_cache(maxsize=1) means
# the model is loaded from disk/HuggingFace exactly once per container process
# and kept in GPU memory for every subsequent task dispatch to that container.

@lru_cache(maxsize=1)
def _colpali_model():
    """Load ColPali-v1.2 into GPU memory and cache the result.

    device_map= is the correct loading pattern for ColPali's PaliGemma
    backbone; it handles weight placement via accelerate. torch.compile is
    skipped — ColPali is GPU-compute-bound and the DynamicBatcher's cross-
    invocation batching is the primary GPU utilisation mechanism.
    """
    import torch
    from colpali_engine.models import ColPali, ColPaliProcessor

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = ColPali.from_pretrained(
        "vidore/colpali-v1.2",
        torch_dtype=torch.bfloat16,
        device_map=device,
    )
    processor = ColPaliProcessor.from_pretrained("vidore/colpali-v1.2")
    return model, processor, device

@lru_cache(maxsize=1)
def _siglip_model():
    """Load SigLIP SO400M into GPU memory, compile it, and cache the result.

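    torch.compile(mode="reduce-overhead") wraps the model so that calls through
    the wrapper are compiled by TorchInductor on first use and replayed with CUDA
    graphs; like the model load, that cost is paid once per warm container.
    Note that index_siglip and the search batcher call model.vision_model and
    model.text_model directly. Attribute access on the compiled wrapper returns
    the original submodules, so those calls currently run eagerly.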
    """
    import torch
    from transformers import AutoModel, AutoProcessor

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModel.from_pretrained("google/siglip-so400m-patch14-224").to(device)
    if device == "cuda":
        model = torch.compile(model, mode="reduce-overhead")
    processor = AutoProcessor.from_pretrained("google/siglip-so400m-patch14-224")
    return model, processor, device

@lru_cache(maxsize=1)
def _ocr_model():
    """Load the doctr OCR predictor onto GPU and cache it.

    doctr's ocr_predictor bundles a detection model (DBNet) and a
    recognition model (CRNN/SAR) into a single callable. pretrained=True
    downloads both model weights from doctr's model zoo on first use.
    """
    import torch
    from doctr.models import ocr_predictor

    predictor = ocr_predictor(pretrained=True)
    if torch.cuda.is_available():
        predictor = predictor.cuda()
    return predictor
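
# Hypothetical illustration, not called by the pipeline: why @lru_cache(maxsize=1)
# on a zero-argument loader gives "load once per process" semantics. The first
# call runs the loader body; every later call in the same process (for example,
# a later task dispatched to the same warm container) gets the cached object back.
# _example_count_loads is a made-up name, and object() stands in for the
# (model, processor, device) tuple returned by the real loaders above.
def _example_count_loads(n_calls: int) -> tuple[int, bool]:
    """Call a cached loader n_calls times; return (real loads, all the same object?)."""
    from functools import lru_cache

    loads: list[int] = []

    @lru_cache(maxsize=1)
    def load_model() -> object:
        loads.append(1)  # only reached on a cache miss
        return object()

    results = [load_model() for _ in range(n_calls)]
    return len(loads), all(r is results[0] for r in results)

# Usage: _example_count_loads(5) returns (1, True). A new container process
# starts with an empty cache, so the first task there pays the load again.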

# ─────────────────────────────────────────────────────────────────────────────
# Search batcher singletons
# ─────────────────────────────────────────────────────────────────────────────
# One DynamicBatcher per model, shared across all concurrent search task
# invocations on the same warm container (concurrency=8). Queries from every
# concurrent caller are aggregated into a single GPU batch, maximizing
# throughput compared to each invocation running its own forward pass.
#
# Initialised lazily on the first search call via double-checked locking and
# lives for the container's lifetime. The process_fn runs GPU work via
# asyncio.to_thread so the aggregation loop can continue collecting queries
# from other callers while the GPU processes the current batch.
#
# File is not hashable, so alru_cache cannot be used here; module-level state
# guarded by an asyncio.Lock is a simple alternative.
#
# Assumption: index_colpali/index_siglip use cache="auto", so the same corpus
# always produces the same index File across all callers on this container. If
# the index file ever changed between calls, the batcher would silently continue
# using the corpus embeddings loaded from the first call.

_colpali_batcher: DynamicBatcher | None = None
_colpali_batcher_lock = asyncio.Lock()
_siglip_batcher: DynamicBatcher | None = None
_siglip_batcher_lock = asyncio.Lock()
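
# Hypothetical sketch, not the extras.DynamicBatcher API: a minimal asyncio
# micro-batcher showing the aggregation idea described above. Callers await
# submit(item); a background loop takes the first queued item, waits a short
# fixed window so concurrent callers can enqueue theirs, drains up to
# max_batch_size items, runs process_fn once on the whole batch, and resolves
# each caller's future. _ExampleMiniBatcher and _example_run_batcher are made-up
# names; the real DynamicBatcher also takes cost and prefetch settings
# (target_batch_cost, prefetch_batches), which this sketch omits.
import asyncio

class _ExampleMiniBatcher:
    """Hypothetical micro-batcher for illustration only."""

    def __init__(self, process_fn, max_batch_size: int = 8, batch_timeout_s: float = 0.01) -> None:
        self._process_fn = process_fn  # async callable: list[item] -> list[result]
        self._max = max_batch_size
        self._timeout = batch_timeout_s
        self._queue: asyncio.Queue = asyncio.Queue()
        self._task: asyncio.Task | None = None
        self.batch_sizes: list[int] = []  # recorded so the batching is observable

    def start(self) -> None:
        self._task = asyncio.create_task(self._run())

    async def stop(self) -> None:
        if self._task is not None:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass

    async def submit(self, item):
        fut = asyncio.get_running_loop().create_future()
        await self._queue.put((item, fut))
        return await fut

    async def _run(self) -> None:
        while True:
            batch = [await self._queue.get()]  # wait for the first item
            await asyncio.sleep(self._timeout)  # aggregation window for other callers
            while len(batch) < self._max and not self._queue.empty():
                batch.append(self._queue.get_nowait())
            self.batch_sizes.append(len(batch))
            try:
                results = await self._process_fn([item for item, _ in batch])
            except Exception as exc:  # propagate the failure to every caller
                for _, fut in batch:
                    if not fut.done():
                        fut.set_exception(exc)
                continue
            for (_, fut), result in zip(batch, results):
                if not fut.done():  # the caller may have been cancelled
                    fut.set_result(result)

async def _example_run_batcher(n_items: int, max_batch_size: int) -> tuple[list[int], list[int]]:
    """Submit n_items concurrently; return (results, observed batch sizes)."""

    async def double(xs: list[int]) -> list[int]:  # stands in for GPU work
        return [x * 2 for x in xs]

    batcher = _ExampleMiniBatcher(double, max_batch_size=max_batch_size)
    batcher.start()
    results = await asyncio.gather(*(batcher.submit(i) for i in range(n_items)))
    await batcher.stop()
    return list(results), batcher.batch_sizes

# Usage: ten concurrent submits with max_batch_size=4 are served by a few
# process_fn calls (batches of at most 4) instead of ten separate ones.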

async def _get_colpali_search_batcher(index_file: File) -> DynamicBatcher:
    """Return the process-level ColPali search batcher, creating it on first call."""
    global _colpali_batcher
    if _colpali_batcher is not None:
        return _colpali_batcher
    async with _colpali_batcher_lock:
        if _colpali_batcher is not None:
            return _colpali_batcher

        import torch

        data = await _load_npz(index_file)
        corpus_emb = torch.from_numpy(data["embeddings"])  # (n_pages, n_patches, dim)
        index_page_ids: list[str] = list(data["page_ids"])
        model, processor, device = _colpali_model()
        corpus_emb = corpus_emb.to(device, dtype=torch.float32)

        async def colpali_process_fn(batch: list[PageQuery]) -> list[list[str]]:
            def _gpu_work() -> list[list[str]]:
                query_inputs = processor.process_queries([q.text for q in batch])
                query_inputs = {k: v.to(device) for k, v in query_inputs.items()}
                with torch.no_grad():
                    query_embs = model(**query_inputs).float()  # (B, T, D)
                    query_chunk = 8
                    n_pages = corpus_emb.shape[0]
                    all_scores = torch.empty(len(batch), n_pages, device=device)
                    for start in range(0, len(batch), query_chunk):
                        chunk = query_embs[start : start + query_chunk]
                        all_scores[start : start + query_chunk] = (
                            torch.einsum("ctd,pjd->ctpj", chunk, corpus_emb)
                            .max(dim=3).values
                            .sum(dim=1)
                        )
                    sorted_indices = all_scores.argsort(dim=1, descending=True).cpu().tolist()
                return [[index_page_ids[j] for j in ranked] for ranked in sorted_indices]

            # Run GPU work in a thread so the event loop — and the batcher's
            # aggregation loop — remain unblocked while the GPU is busy.
            return await asyncio.to_thread(_gpu_work)

        batcher: DynamicBatcher[PageQuery, list[str]] = DynamicBatcher(
            process_fn=colpali_process_fn,
            target_batch_cost=128,
            max_batch_size=128,
            batch_timeout_s=0.05,
            default_cost=1,
            prefetch_batches=2,
        )
        await batcher.start()
        _colpali_batcher = batcher
    return _colpali_batcher

async def _get_siglip_search_batcher(index_file: File) -> DynamicBatcher:
    """Return the process-level SigLIP search batcher, creating it on first call."""
    global _siglip_batcher
    if _siglip_batcher is not None:
        return _siglip_batcher
    async with _siglip_batcher_lock:
        if _siglip_batcher is not None:
            return _siglip_batcher

        import torch

        data = await _load_npz(index_file)
        corpus_emb = torch.from_numpy(data["embeddings"])  # (n_pages, dim), L2-normalised
        index_page_ids: list[str] = list(data["page_ids"])
        model, processor, device = _siglip_model()
        corpus_emb = corpus_emb.to(device)

        async def siglip_process_fn(batch: list[PageQuery]) -> list[list[str]]:
            def _gpu_work() -> list[list[str]]:
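                # SigLIP pools the text embedding from the last token position
                # and was trained on text padded to a fixed length. padding=True
                # would make a query's embedding depend on the longest query in
                # its batch; padding="max_length" (as in the model card) does not.
                text_inputs = processor(
                    text=[q.text for q in batch],
                    return_tensors="pt",
                    padding="max_length",
                    truncation=True,
                ).to(device)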
                with torch.no_grad():
                    text_out = model.text_model(**text_inputs)
                    query_embs = text_out.pooler_output  # (B, dim)
                    query_embs = query_embs / query_embs.norm(dim=-1, keepdim=True)
                    scores_matrix = corpus_emb @ query_embs.T  # (n_pages, B)
                    sorted_indices = scores_matrix.argsort(dim=0, descending=True).T.cpu().tolist()
                return [[index_page_ids[j] for j in ranked] for ranked in sorted_indices]

            return await asyncio.to_thread(_gpu_work)

        batcher = DynamicBatcher(
            process_fn=siglip_process_fn,
            target_batch_cost=128,
            max_batch_size=128,
            batch_timeout_s=0.05,
            default_cost=1,
            prefetch_batches=2,
        )
        await batcher.start()
        _siglip_batcher = batcher
    return _siglip_batcher
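
# Hypothetical sketch, not called by the pipeline: the double-checked
# asyncio.Lock pattern used by _get_colpali_search_batcher and
# _get_siglip_search_batcher, reduced to its core. Many coroutines race to
# fetch a lazily built object. The unlocked first check keeps the fast path
# lock-free, and the second check under the lock ensures that only the first
# winner runs the expensive initialiser. _ExampleLazyValue and
# _example_race_init are made-up names; asyncio.sleep stands in for loading an
# index and a model.
import asyncio

class _ExampleLazyValue:
    def __init__(self) -> None:
        self._value: object | None = None
        self._lock = asyncio.Lock()
        self.init_calls = 0

    async def get(self) -> object:
        if self._value is not None:  # fast path once initialised
            return self._value
        async with self._lock:
            if self._value is None:  # re-check: another coroutine may have won
                self.init_calls += 1
                await asyncio.sleep(0.01)  # expensive, awaiting initialiser
                self._value = object()
        return self._value

async def _example_race_init(n_callers: int) -> tuple[int, int]:
    """Return (initialiser runs, distinct objects handed out) for n concurrent callers."""
    lazy = _ExampleLazyValue()
    values = await asyncio.gather(*(lazy.get() for _ in range(n_callers)))
    return lazy.init_calls, len({id(v) for v in values})

# Usage: asyncio.run(_example_race_init(20)) returns (1, 1). Without the second
# check, every coroutine that queued on the lock would re-run the initialiser.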

# ─────────────────────────────────────────────────────────────────────────────
# Helpers
# ─────────────────────────────────────────────────────────────────────────────

def _batches(items: list, batch_size: int):
    """Yield successive fixed-size batches from a list."""
    for start in range(0, len(items), batch_size):
        yield items[start : start + batch_size]

def _load_image_sync(f: File) -> PILImage.Image:
    """Blocking download + decode. Intended to be called from a thread pool."""
    with f.open_sync("rb") as fh:
        data = fh.read()
    return PILImage.open(BytesIO(data)).convert("RGB")

async def _load_image(f: File) -> PILImage.Image:
    """Download and decode a page image without blocking the event loop.

    asyncio.to_thread runs the blocking _load_image_sync in a worker thread so
    other coroutines keep running during the download. The indexing tasks
    instead submit _load_image_sync via loop.run_in_executor to prefetch the
    next batch while the GPU processes the current one.
    """
    return await asyncio.to_thread(_load_image_sync, f)

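async def _load_npz(index_file: File) -> np.lib.npyio.NpzFile:
    """Download an index File to a local temp path and open it with np.load."""
    with tempfile.NamedTemporaryFile(suffix=".npz", delete=False) as tmp:
        async with index_file.open("rb") as fh:
            tmp.write(bytes(await fh.read()))
    # Load only after the with-block has closed (and flushed) the temp file;
    # reading it while still open could see a truncated archive.
    # delete=False keeps the file on disk for NpzFile's lazy member reads.
    return np.load(tmp.name)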

def _dcg(relevances: list[int]) -> float:
    return sum(rel / math.log2(rank + 2) for rank, rel in enumerate(relevances))

# ─────────────────────────────────────────────────────────────────────────────
# Tasks — data loading
# ─────────────────────────────────────────────────────────────────────────────

@driver.task(cache="auto", retries=3)
async def load_vidore_pages(subset: str = "docvqa", max_pages: int = 200) -> PageDataset:
    """
    Load a ViDoRe benchmark subset and store page images in Flyte's blob store.

    Supports two dataset formats:

    Legacy (subsampled): a single 'test' split with one row per (query, page)
    pair; fields: image, query, image_filename. max_pages caps the number of
    rows read, so the number of unique pages can be smaller when several
    queries share a page. streaming=True fetches rows on demand as islice
    consumes them instead of downloading the whole dataset first.
    Datasets: vidore/docvqa_test_subsampled, vidore/infovqa_test_subsampled

    V3: separate corpus / queries / qrels configs (selected with name=),
    following the BEIR retrieval benchmark format. corpus contains page images;
    queries contains question text; qrels maps query IDs to relevant corpus
    page IDs (many-to-many).
    Datasets: vidore/vidore_v3_finance_en  (~2 942 pages, 1 854 queries)

    The first call uploads page images to Flyte's blob store and caches the
    PageDataset; every subsequent call with the same arguments returns the
    cached result instantly. retries=3 guards against transient HuggingFace
    network failures.

    Available subsets: "docvqa", "infovqa", "vidore_v3_finance_en"
    """
    from datasets import load_dataset

    subset_map = {
        "docvqa": "vidore/docvqa_test_subsampled",
        "infovqa": "vidore/infovqa_test_subsampled",
        "vidore_v3_finance_en": "vidore/vidore_v3_finance_en",
    }
    dataset_name = subset_map.get(subset, f"vidore/{subset}_test_subsampled")

    # V3 datasets ship with separate corpus / queries / qrels configs.
    _V3_SUBSETS = {"vidore_v3_finance_en"}

    if subset in _V3_SUBSETS:
        # ── V3 format ─────────────────────────────────────────────────────────
        # corpus / queries / qrels are HuggingFace configs (name=), not splits.
        # corpus uses streaming=True so images are decoded one at a time —
        # loading all 2 942 rows eagerly would hold gigabytes of PIL images in
        # the driver's RAM simultaneously. qrels and queries are text-only and
        # small enough to load fully into memory.
        corpus_ds = load_dataset(dataset_name, name="corpus", split="test", streaming=True)
        qrels_ds = load_dataset(dataset_name, name="qrels", split="test")
        queries_ds = load_dataset(dataset_name, name="queries", split="test")

        # Normalise field names — V3 follows BEIR convention (hyphenated ids).
        def _col(ds, *candidates):
            cols = set(ds.column_names)
            for c in candidates:
                if c in cols:
                    return c
            raise KeyError(f"None of {candidates} found in columns {cols}")

        corpus_id_col = _col(corpus_ds, "corpus-id", "corpus_id", "id", "_id")
        query_id_col = _col(queries_ds, "query-id", "query_id", "id", "_id")
        query_text_col = _col(queries_ds, "query", "text")
        qrel_qid_col = _col(qrels_ds, "query-id", "query_id")
        qrel_cid_col = _col(qrels_ds, "corpus-id", "corpus_id")

        # Slice corpus to max_pages, upload each image to Flyte blob store.
        page_ids: list[str] = []
        page_files: list[File] = []
        corpus_id_to_page_id: dict[str, str] = {}

        for i, row in enumerate(islice(corpus_ds, max_pages)):
            img = row.get("image")
            if not isinstance(img, PILImage.Image):
                continue
            cid = str(row[corpus_id_col])
            page_id = f"{subset}_{i:04d}"
            with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f:
                tmp_path = f.name
                img.convert("RGB").save(tmp_path, format="JPEG")
            del img  # free PIL memory before upload
            page_file = await File.from_local(tmp_path)
            os.unlink(tmp_path)
            corpus_id_to_page_id[cid] = page_id
            page_ids.append(page_id)
            page_files.append(page_file)

        # Build query_id → relevant page_id from qrels (first match wins).
        # Only keep relevance judgements whose corpus page is in our slice.
        qrel_map: dict[str, str] = {}
        for row in qrels_ds:
            qid = str(row[qrel_qid_col])
            cid = str(row[qrel_cid_col])
            if cid in corpus_id_to_page_id and qid not in qrel_map:
                qrel_map[qid] = corpus_id_to_page_id[cid]

        # Collect queries that have at least one relevant page in our slice.
        queries: list[PageQuery] = []
        for row in queries_ds:
            qid = str(row[query_id_col])
            if qid not in qrel_map:
                continue
            queries.append(
                PageQuery(
                    query_id=qid,
                    text=str(row[query_text_col]),
                    relevant_page_id=qrel_map[qid],
                )
            )

    else:
        # ── Legacy format ─────────────────────────────────────────────────────
        # Single 'test' split with one row per (query, page) pair.
        ds = load_dataset(dataset_name, split="test", streaming=True)

        page_ids = []
        page_files = []
        queries = []
        seen_pages: dict[str, str] = {}  # image_filename → page_id

        for i, row in enumerate(islice(ds, max_pages)):
            img = row.get("image")
            if not isinstance(img, PILImage.Image):
                continue
            filename: str = row.get("image_filename") or f"page_{i}"
            query_text: str = row.get("query", "")
            if not query_text:
                continue

            # Each unique page is uploaded exactly once; multiple queries may
            # share the same page (same image_filename).
            if filename not in seen_pages:
                page_id = f"{subset}_{len(page_ids):04d}"
                with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f:
                    tmp_path = f.name
                    img.convert("RGB").save(tmp_path, format="JPEG")
                del img  # free PIL memory before upload
                page_file = await File.from_local(tmp_path)
                os.unlink(tmp_path)
                seen_pages[filename] = page_id
                page_ids.append(page_id)
                page_files.append(page_file)
            else:
                page_id = seen_pages[filename]

            queries.append(
                PageQuery(
                    query_id=f"q{i:04d}",
                    text=query_text,
                    relevant_page_id=page_id,
                )
            )

    print(f"Loaded {len(page_ids)} unique pages, {len(queries)} queries", flush=True)
    return PageDataset(page_ids=page_ids, page_files=page_files, queries=queries)
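
# Hypothetical sketch, not called by the pipeline: the two ID-mapping steps in
# load_vidore_pages, stripped of the HuggingFace and blob-store calls so they
# can be checked with plain Python data. _example_build_qrel_map mirrors the V3
# branch: the first qrel whose corpus page survived slicing wins, and queries
# with no surviving page are dropped. _example_dedup_pages mirrors the legacy
# branch: one page ID per unique image_filename, shared by every query on that
# page. Both names and the tuple-based inputs are made up for illustration.
def _example_build_qrel_map(
    qrels: list[tuple[str, str]],  # (query_id, corpus_id) rows
    corpus_id_to_page_id: dict[str, str],
) -> dict[str, str]:
    qrel_map: dict[str, str] = {}
    for qid, cid in qrels:
        if cid in corpus_id_to_page_id and qid not in qrel_map:
            qrel_map[qid] = corpus_id_to_page_id[cid]
    return qrel_map

def _example_dedup_pages(
    rows: list[tuple[str, str]],  # (image_filename, query_text) rows
    subset: str,
) -> tuple[list[str], list[tuple[str, str]]]:
    """Return (unique page IDs, [(query_text, page_id), ...])."""
    seen: dict[str, str] = {}
    page_ids: list[str] = []
    pairs: list[tuple[str, str]] = []
    for filename, query in rows:
        if filename not in seen:
            seen[filename] = f"{subset}_{len(page_ids):04d}"
            page_ids.append(seen[filename])
        pairs.append((query, seen[filename]))
    return page_ids, pairs

# Usage: with qrels [("q1", "c9"), ("q1", "c2"), ("q2", "c7")] and a slice that
# kept only {"c2": "p0"}, _example_build_qrel_map returns {"q1": "p0"}.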

# ─────────────────────────────────────────────────────────────────────────────
# Tasks — indexing
# ─────────────────────────────────────────────────────────────────────────────

@colpali_indexer.task(cache="auto", retries=2)
async def index_colpali(page_ids: list[str], page_files: list[File]) -> File:
    """
    Encode every page with ColPali-v1.2 and save the multi-vector index.

    ColPali skips OCR entirely. It feeds the raw page image into PaliGemma
    (a vision-language model) and produces one embedding vector per image
    patch — roughly 1,024 patches per page, each of dimension 128.

    _colpali_model() is an lru_cache'd loader. On a cold container, it
    downloads and loads the model once. On a warm container (kept alive by
    ReusePolicy), it returns the already-loaded model instantly from cache —
    no repeated ~7 GB download.

    The index is stored as a .npz file in Flyte's blob store:
      embeddings — float32, shape (n_pages, n_patches, dim)
      page_ids   — matching page ID strings

    cache="auto" + retries=2: the result is cached on success, so later runs
    with the same inputs skip indexing; transient failures (e.g. HuggingFace
    rate limits) are retried twice.
    """
    import torch

    model, processor, device = _colpali_model()

    loop = asyncio.get_running_loop()
    batches = list(_batches(page_files, 4))
    n_batches = len(batches)

    # Submit the first batch to the thread pool before entering the loop so
    # that downloads are already in flight when we first await them.
    prefetch = [loop.run_in_executor(None, _load_image_sync, f) for f in batches[0]]

    all_embeddings: list[np.ndarray] = []
    for batch_idx in range(n_batches):
        images = list(await asyncio.gather(*prefetch))

        # Submit next batch downloads immediately — OS threads run these in
        # parallel with the GPU forward pass below.
        if batch_idx + 1 < n_batches:
            prefetch = [loop.run_in_executor(None, _load_image_sync, f) for f in batches[batch_idx + 1]]

        inputs = processor.process_images(images)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            emb = model(**inputs)  # (batch, n_patches, dim)

        all_embeddings.append(emb.cpu().float().numpy())
        print(f"ColPali: indexed batch {batch_idx + 1}/{n_batches}", flush=True)

    embeddings = np.concatenate(all_embeddings, axis=0)  # (n_pages, n_patches, dim)
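    # A fresh temp dir per call, so concurrent calls on a reused container never
    # overwrite each other's index file before it is uploaded.
    out_path = os.path.join(tempfile.mkdtemp(), "colpali_index.npz")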
    np.savez(out_path, embeddings=embeddings, page_ids=np.array(page_ids))
    return await File.from_local(out_path)
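
# Hypothetical sketch, not called by the pipeline: the prefetch pattern used
# by index_colpali and index_siglip. Downloads for batch i+1 are submitted to
# the default thread pool (loop.run_in_executor) before the synchronous forward
# pass for batch i runs. Blocking I/O in worker threads therefore overlaps with
# compute on the event-loop thread. _example_prefetch_pipeline is a made-up name.
# The sleep stands in for _load_image_sync, and forward() for the model. Unlike
# the two tasks, this sketch returns early for empty input instead of failing
# on batches[0].
import asyncio
import time

async def _example_prefetch_pipeline(items: list[int], batch_size: int) -> list[int]:
    def load(x: int) -> int:  # blocking "download", runs in a worker thread
        time.sleep(0.001)
        return x * 10

    def forward(batch: list[int]) -> list[int]:  # synchronous "GPU" step
        return [v + 1 for v in batch]

    batches = [items[i : i + batch_size] for i in range(0, len(items), batch_size)]
    if not batches:  # nothing to index
        return []
    loop = asyncio.get_running_loop()
    prefetch = [loop.run_in_executor(None, load, x) for x in batches[0]]
    outputs: list[int] = []
    for idx in range(len(batches)):
        loaded = list(await asyncio.gather(*prefetch))
        if idx + 1 < len(batches):  # start the next downloads first
            prefetch = [loop.run_in_executor(None, load, x) for x in batches[idx + 1]]
        outputs.extend(forward(loaded))  # runs while those downloads proceed
    return outputs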

@siglip_indexer.task(cache="auto", retries=2)
async def index_siglip(page_ids: list[str], page_files: list[File]) -> File:
    """
    Encode every page with SigLIP SO400M and save the single-vector index.

    SigLIP (2023) is Google's CLIP-style image-text model, trained with a
    pairwise sigmoid loss instead of CLIP's softmax contrastive loss, so the
    loss needs no normalisation over the whole batch. Produces one global
    embedding per page.

    _siglip_model() caches the model across warm container reuses.

    The index is stored as a .npz file:
      embeddings — float32, shape (n_pages, dim), L2-normalised
      page_ids   — matching page ID strings
    """
    import torch

    model, processor, device = _siglip_model()

    loop = asyncio.get_running_loop()
    batches = list(_batches(page_files, 8))
    n_batches = len(batches)

    # Submit the first batch to the thread pool before entering the loop so
    # that downloads are already in flight when we first await them.
    prefetch = [loop.run_in_executor(None, _load_image_sync, f) for f in batches[0]]

    all_embeddings: list[np.ndarray] = []
    for batch_idx in range(n_batches):
        images = list(await asyncio.gather(*prefetch))

        # Submit next batch downloads immediately — OS threads run these in
        # parallel with the GPU forward pass below.
        if batch_idx + 1 < n_batches:
            prefetch = [loop.run_in_executor(None, _load_image_sync, f) for f in batches[batch_idx + 1]]

        inputs = processor(images=images, return_tensors="pt", padding=True).to(device)

        with torch.no_grad():
            outputs = model.vision_model(**inputs)
            emb = outputs.pooler_output  # (batch, dim)
            emb = emb / emb.norm(dim=-1, keepdim=True)  # L2 normalise

        all_embeddings.append(emb.cpu().float().numpy())
        print(f"SigLIP: indexed batch {batch_idx + 1}/{n_batches}", flush=True)

    embeddings = np.concatenate(all_embeddings, axis=0)  # (n_pages, dim)
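    # A fresh temp dir per call, so concurrent calls on a reused container never
    # overwrite each other's index file before it is uploaded.
    out_path = os.path.join(tempfile.mkdtemp(), "siglip_index.npz")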
    np.savez(out_path, embeddings=embeddings, page_ids=np.array(page_ids))
    return await File.from_local(out_path)

@ocr_engine.task(cache="auto")
async def extract_page_texts(page_files: list[File]) -> list[str]:
    """
    OCR every page with doctr on GPU to produce a text-only baseline.

    doctr bundles DBNet (detection) + CRNN/SAR (recognition) into a single
    callable predictor. Each batch of ocr_batch_size pages is downloaded in
    parallel and then passed to the predictor; asyncio.to_thread keeps the
    event loop unblocked during GPU inference.

    Result structure: result.pages[i].blocks[j].lines[k].words[l].value

    Cached: the same corpus is OCR'd at most once across all experiments
    that use the OCR+BM25 backend.
    """
    import gc

    predictor = _ocr_model()

    # Process in batches: download each batch just-in-time so only
    # ocr_batch_size images are in memory at once instead of the whole corpus.
    ocr_batch_size = 8
    total = len(page_files)
    texts: list[str] = []
    for start in range(0, total, ocr_batch_size):
        batch_files = page_files[start : start + ocr_batch_size]
        batch_images = list(
            await asyncio.gather(*[asyncio.to_thread(_load_image_sync, f) for f in batch_files])
        )
        batch_np = [np.array(img) for img in batch_images]
        del batch_images
        result = await asyncio.to_thread(predictor, batch_np)
        del batch_np
        for page_output in result.pages:
            texts.append(
                "\n".join(
                    " ".join(word.value for word in line.words)
                    for block in page_output.blocks
                    for line in block.lines
                )
            )
        del result
        gc.collect()
        print(f"OCR: processed {min(start + ocr_batch_size, total)}/{total} pages", flush=True)

    return texts
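
# Hypothetical sketch, not called by the pipeline: how extract_page_texts
# flattens doctr's page -> block -> line -> word hierarchy into one string per
# page (words joined by spaces, lines by newlines). The SimpleNamespace objects
# only mimic the attribute layout named in the docstring above. They are not
# doctr classes, and _example_mock_page / _example_flatten_ocr are made-up names.
from types import SimpleNamespace

def _example_mock_page(blocks: list[list[str]]) -> SimpleNamespace:
    """Build a fake page: each block is a list of line strings."""
    return SimpleNamespace(
        blocks=[
            SimpleNamespace(
                lines=[
                    SimpleNamespace(words=[SimpleNamespace(value=w) for w in line.split()])
                    for line in block
                ]
            )
            for block in blocks
        ]
    )

def _example_flatten_ocr(pages: list) -> list[str]:
    return [
        "\n".join(
            " ".join(word.value for word in line.words)
            for block in page.blocks
            for line in block.lines
        )
        for page in pages
    ]

# Usage: a page with blocks [["Revenue grew 12%", "in Q3"], ["Total 4.2B"]]
# flattens to "Revenue grew 12%\nin Q3\nTotal 4.2B"; block boundaries are lost.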

# ─────────────────────────────────────────────────────────────────────────────
# Tasks — search
# ─────────────────────────────────────────────────────────────────────────────

# {{docs-fragment search_colpali}}
@colpali_indexer.task
async def search_colpali(
    index_file: File,
    queries: list[PageQuery],
    top_k: int,
) -> list[RetrievalResult]:
    """
    Retrieve pages using ColPali MaxSim late interaction via DynamicBatcher.

    MaxSim score for page p given query q:
        score(q, p) = Σ_{t ∈ query tokens} max_{j ∈ page patches} (q_t · p_j)

    Each query is submitted to the process-level DynamicBatcher, which
    aggregates queries from all concurrent search_colpali invocations on the
    same warm container (concurrency=8) into a single GPU batch. This keeps
    the GPU saturated rather than running one small batch per caller.

    The batcher's process_fn runs GPU work in asyncio.to_thread, so the
    aggregation loop stays live while the GPU encodes and scores.
    """
    batcher = await _get_colpali_search_batcher(index_file)
    futures = await batcher.submit_batch(queries)
    all_ranked: list[list[str]] = list(await asyncio.gather(*futures))

    return [
        RetrievalResult(query_id=q.query_id, ranked_page_ids=ranked[:top_k])
        for q, ranked in zip(queries, all_ranked)
    ]
# {{/docs-fragment search_colpali}}
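
# Hypothetical sketch, not called by the pipeline: the MaxSim score from the
# search_colpali docstring, written in plain Python for toy 2-D vectors. For
# each query token, take its best dot product over the page's patches, then sum
# over tokens. The batched torch.einsum("ctd,pjd->ctpj", ...) followed by
# .max(dim=3) and .sum(dim=1) in _get_colpali_search_batcher computes the same
# quantity for many queries and pages at once. _example_maxsim and
# _example_rank_maxsim are made-up names.
def _example_maxsim(query_tokens: list[list[float]], page_patches: list[list[float]]) -> float:
    def dot(a: list[float], b: list[float]) -> float:
        return sum(x * y for x, y in zip(a, b))

    return sum(max(dot(t, p) for p in page_patches) for t in query_tokens)

def _example_rank_maxsim(
    query_tokens: list[list[float]], pages: dict[str, list[list[float]]]
) -> list[str]:
    """Return page IDs ordered best to worst by MaxSim."""
    return sorted(pages, key=lambda pid: -_example_maxsim(query_tokens, pages[pid]))

# Usage: for query tokens [[1, 0], [0, 1]], a page with patches [[1, 0], [0, 0.5]]
# scores 1 + 0.5 = 1.5, while a page with the single patch [0.9, 0.9] scores
# 0.9 + 0.9 = 1.8, so it ranks first.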

@siglip_indexer.task
async def search_siglip(
    index_file: File,
    queries: list[PageQuery],
    top_k: int,
) -> list[RetrievalResult]:
    """
    Retrieve pages using SigLIP cosine similarity via DynamicBatcher.

    Each query is submitted to the process-level DynamicBatcher, which
    aggregates queries from all concurrent search_siglip invocations on the
    same warm container (concurrency=8) into a single GPU batch.

    SigLIP's single-vector embeddings make full vectorisation safe —
    the scores matrix (n_pages x n_queries) is small enough to materialise
    in one GPU call regardless of batch size.
    """
    batcher = await _get_siglip_search_batcher(index_file)
    futures = await batcher.submit_batch(queries)
    all_ranked: list[list[str]] = list(await asyncio.gather(*futures))

    return [
        RetrievalResult(query_id=q.query_id, ranked_page_ids=ranked[:top_k])
        for q, ranked in zip(queries, all_ranked)
    ]
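
# Hypothetical sketch, not called by the pipeline: the SigLIP ranking step in
# plain Python. After L2 normalisation, a dot product equals cosine similarity,
# so the scores matrix corpus_emb @ query_embs.T has shape (n_pages, n_queries).
# Sorting each column in descending order (argsort(dim=0) then .T in the
# batcher) gives one ranked page list per query. _example_cosine_rank is a
# made-up name.
import math

def _example_cosine_rank(corpus: list[list[float]], queries: list[list[float]]) -> list[list[int]]:
    def l2_normalise(v: list[float]) -> list[float]:
        norm = math.sqrt(sum(x * x for x in v))
        return [x / norm for x in v]

    pages = [l2_normalise(v) for v in corpus]
    qs = [l2_normalise(v) for v in queries]
    # scores[p][q] = cosine(page p, query q), shape (n_pages, n_queries)
    scores = [[sum(a * b for a, b in zip(p, q)) for q in qs] for p in pages]
    # One descending sort per column (query) gives that query's ranked page indices.
    return [
        sorted(range(len(pages)), key=lambda p: -scores[p][col])
        for col in range(len(qs))
    ]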

@driver.task
async def search_bm25(
    page_texts: list[str],
    page_ids: list[str],
    queries: list[PageQuery],
    top_k: int,
) -> list[RetrievalResult]:
    """
    Retrieve pages using BM25 over OCR'd text.

    The standard keyword-based baseline. No GPU required; strong on
    text-dense pages, weak on visual content that doctr OCR cannot read.
    """
    tokenized = [text.lower().split() for text in page_texts]
    bm25 = BM25Okapi(tokenized)

    results: list[RetrievalResult] = []
    for q in queries:
        scores = bm25.get_scores(q.text.lower().split())
        ranked = sorted(range(len(page_ids)), key=lambda i: -scores[i])[:top_k]
        results.append(
            RetrievalResult(
                query_id=q.query_id,
                ranked_page_ids=[page_ids[i] for i in ranked],
            )
        )
    return results
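
# Hypothetical sketch, not rank_bm25's implementation: textbook Okapi BM25 in
# plain Python, to show roughly what BM25Okapi.get_scores computes. Each query
# term t found in document d adds
#   idf(t) * tf * (k1 + 1) / (tf + k1 * (1 - b + b * |d| / avgdl)).
# This sketch uses the non-negative idf log(1 + (N - n_t + 0.5) / (n_t + 0.5)).
# rank_bm25 handles idf differently (it floors negative values), so absolute
# scores will differ. k1=1.5 and b=0.75 are common defaults.
# _example_bm25_scores is a made-up name.
import math

def _example_bm25_scores(
    docs: list[list[str]], query: list[str], k1: float = 1.5, b: float = 0.75
) -> list[float]:
    n_docs = len(docs)
    avgdl = (sum(len(d) for d in docs) / n_docs) if n_docs else 1.0
    avgdl = avgdl or 1.0  # all-empty corpus: avoid dividing by zero
    df: dict[str, int] = {}  # document frequency per term
    for d in docs:
        for term in set(d):
            df[term] = df.get(term, 0) + 1
    scores: list[float] = []
    for d in docs:
        score = 0.0
        for term in query:
            tf = d.count(term)
            if tf == 0:  # exact-token match only
                continue
            n_t = df[term]
            idf = math.log(1 + (n_docs - n_t + 0.5) / (n_t + 0.5))
            score += idf * tf * (k1 + 1) / (tf + k1 * (1 - b + b * len(d) / avgdl))
        scores.append(score)
    return scores

# Usage: because search_bm25 tokenises with str.split(), a query token like "q3?"
# never matches a page token "q3"; only exact tokens contribute to the score.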

# ─────────────────────────────────────────────────────────────────────────────
# Tasks — evaluation
# ─────────────────────────────────────────────────────────────────────────────

@driver.task
async def evaluate(
    results: list[RetrievalResult],
    ground_truth: list[PageQuery],
    k: int,
) -> Metrics:
    """
    Compute Recall@K, NDCG@K, and MRR for a single retrieval model.

    Recall@K  — was the correct page in the top-K results?
    NDCG@K    — normalised discounted cumulative gain; rewards earlier hits.
    MRR       — mean reciprocal rank of the first correct result.

    All three are averaged over all queries. Higher is better.
    """
    gt_map = {q.query_id: q.relevant_page_id for q in ground_truth}
    recall_vals, ndcg_vals, mrr_vals = [], [], []

    for r in results:
        relevant = gt_map.get(r.query_id, "")
        top = r.ranked_page_ids[:k]

        recall_vals.append(1.0 if relevant in top else 0.0)

        rels = [1 if pid == relevant else 0 for pid in top]
        idcg = _dcg([1])  # ideal: correct page at rank 1
        ndcg_vals.append(_dcg(rels) / idcg if idcg > 0 else 0.0)

        rr = 0.0
        for rank, pid in enumerate(r.ranked_page_ids, start=1):
            if pid == relevant:
                rr = 1.0 / rank
                break
        mrr_vals.append(rr)

    return Metrics(
        recall_at_k=float(np.mean(recall_vals)),
        ndcg_at_k=float(np.mean(ndcg_vals)),
        mrr=float(np.mean(mrr_vals)),
        k=k,
    )

# ─────────────────────────────────────────────────────────────────────────────
# Tasks — report
# ─────────────────────────────────────────────────────────────────────────────

@driver.task(report=True)
async def generate_report(report: ComparisonReport) -> None:
    """
    Emit an interactive HTML report visible in the Flyte UI.

    report=True marks this task as a reporting task. Flyte renders the HTML
    written with flyte.report.replace.aio() directly on the execution detail
    page, so no separate dashboard or export step is required.

    The report contains:
      - Summary cards: experiment count, best model, best Recall@K.
      - Grouped bar chart: Recall@K, NDCG@K, MRR side-by-side per experiment.
      - Ranked results table with all three metrics.
    """
    sorted_results = sorted(report.results, key=lambda r: -r.metrics.recall_at_k)
    best = sorted_results[0]

    labels = [r.config.name for r in sorted_results]
    recall_vals = [r.metrics.recall_at_k for r in sorted_results]
    ndcg_vals = [r.metrics.ndcg_at_k for r in sorted_results]
    mrr_vals = [r.metrics.mrr for r in sorted_results]

    table_rows = "".join(
        f"""
        <tr>
          <td>{r.config.name}</td>
          <td>{r.config.model.value}</td>
          <td>{r.metrics.recall_at_k:.3f}</td>
          <td>{r.metrics.ndcg_at_k:.3f}</td>
          <td>{r.metrics.mrr:.3f}</td>
          <td>{r.metrics.k}</td>
        </tr>"""
        for r in sorted_results
    )

    html = f"""<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <title>Visual Document Retrieval — Results</title>
  <script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
  <style>
    * {{ box-sizing: border-box; margin: 0; padding: 0; }}
    body {{
      font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;
      background: #f0f2f5; color: #222; padding: 24px;
    }}
    h1 {{ font-size: 1.6em; margin-bottom: 4px; }}
    .subtitle {{ color: #666; margin-bottom: 24px; font-size: 0.95em; }}
    .cards {{
      display: flex; gap: 16px; flex-wrap: wrap; margin-bottom: 28px;
    }}
    .card {{
      background: #fff; border-radius: 10px; padding: 18px 24px;
      box-shadow: 0 1px 4px rgba(0,0,0,.08); min-width: 160px;
    }}
    .card-value {{ font-size: 1.9em; font-weight: 700; color: #4f46e5; }}
    .card-label {{ font-size: 0.8em; color: #888; text-transform: uppercase;
                   letter-spacing: .04em; margin-top: 2px; }}
    .chart-box {{
      background: #fff; border-radius: 10px; padding: 24px;
      box-shadow: 0 1px 4px rgba(0,0,0,.08); margin-bottom: 28px;
    }}
    .chart-box h2 {{ font-size: 1em; margin-bottom: 16px; color: #444; }}
    table {{ width: 100%; border-collapse: collapse; font-size: 0.9em; }}
    th {{
      background: #4f46e5; color: #fff; padding: 10px 14px;
      text-align: left; font-weight: 600;
    }}
    td {{ padding: 9px 14px; border-bottom: 1px solid #eee; }}
    tr:hover td {{ background: #f8f8ff; }}
    tr:first-child td {{ font-weight: 600; }}
  </style>
</head>
<body>
  <h1>Visual Document Retrieval — Experiment Comparison</h1>
  <p class="subtitle">ViDoRe benchmark &middot; {len(report.results)} experiment(s)</p>

  <div class="cards">
    <div class="card">
      <div class="card-value">{len(report.results)}</div>
      <div class="card-label">Experiments</div>
    </div>
    <div class="card">
      <div class="card-value">{best.config.name}</div>
      <div class="card-label">Best by Recall@K</div>
    </div>
    <div class="card">
      <div class="card-value">{best.metrics.recall_at_k:.3f}</div>
      <div class="card-label">Best Recall@{best.metrics.k}</div>
    </div>
    <div class="card">
      <div class="card-value">{best.metrics.ndcg_at_k:.3f}</div>
      <div class="card-label">Best NDCG@{best.metrics.k}</div>
    </div>
    <div class="card">
      <div class="card-value">{best.metrics.mrr:.3f}</div>
      <div class="card-label">Best MRR</div>
    </div>
  </div>

  <div class="chart-box">
    <h2>Metrics by Experiment</h2>
    <canvas id="metricsChart" height="100"></canvas>
  </div>

  <div class="chart-box">
    <h2>Ranked Results</h2>
    <table>
      <thead>
        <tr>
          <th>Experiment</th><th>Model</th>
          <th>Recall@K</th><th>NDCG@K</th><th>MRR</th><th>K</th>
        </tr>
      </thead>
      <tbody>{table_rows}</tbody>
    </table>
  </div>

  <script>
    new Chart(document.getElementById('metricsChart'), {{
      type: 'bar',
      data: {{
        labels: {json.dumps(labels)},
        datasets: [
          {{
            label: 'Recall@K',
            data: {json.dumps(recall_vals)},
            backgroundColor: 'rgba(79,70,229,0.75)',
            borderRadius: 4
          }},
          {{
            label: 'NDCG@K',
            data: {json.dumps(ndcg_vals)},
            backgroundColor: 'rgba(16,185,129,0.75)',
            borderRadius: 4
          }},
          {{
            label: 'MRR',
            data: {json.dumps(mrr_vals)},
            backgroundColor: 'rgba(245,158,11,0.75)',
            borderRadius: 4
          }}
        ]
      }},
      options: {{
        responsive: true,
        plugins: {{ legend: {{ position: 'top' }} }},
        scales: {{
          y: {{ beginAtZero: true, max: 1.0,
               title: {{ display: true, text: 'Score' }} }}
        }}
      }}
    }});
  </script>
</body>
</html>"""

    await flyte.report.replace.aio(html)
    await flyte.report.flush.aio()

# ─────────────────────────────────────────────────────────────────────────────
# Experiment orchestration
# ─────────────────────────────────────────────────────────────────────────────

# {{docs-fragment run_experiment}}
@driver.task
async def run_experiment(config: ExperimentConfig, dataset: PageDataset) -> ExperimentResult:
    """
    End-to-end retrieval pipeline for a single ExperimentConfig.

    Flyte v2's dynamic execution lets this driver task pick its GPU tasks
    (index_colpali/search_colpali or index_siglip/search_siglip) from the
    runtime value of config.model, with no static DAG wiring. The if/elif is
    plain Python; Flyte schedules the selected sub-tasks on the appropriate
    environment.

    Caching: two experiments that share the same model and corpus (e.g. ColPali
    at top_k=5 and top_k=10) hit the same cached index, so indexing GPU work is
    paid at most once per (model, corpus) pair across all experiments. Search
    still runs once per experiment, because each config passes its own top_k.

    Search queries are sharded into chunks of SEARCH_SHARD_SIZE and dispatched
    as concurrent task invocations. All shards land on the single warm container
    (replicas=1) and feed the same DynamicBatcher simultaneously, keeping the
    GPU saturated throughout search rather than processing one large sequential
    batch from a single caller.

    flyte.group wraps each experiment in a named span in the Flyte UI, making
    it easy to compare latencies and drill into individual runs.
    """
    SEARCH_SHARD_SIZE = 256

    with flyte.group(config.name):
        if config.model == RetrievalModel.COLPALI:
            index_file = await index_colpali(dataset.page_ids, dataset.page_files)
            shards = list(_batches(dataset.queries, SEARCH_SHARD_SIZE))
            shard_results = await asyncio.gather(
                *[search_colpali(index_file, shard, config.top_k) for shard in shards]
            )
            results = [r for shard in shard_results for r in shard]

        elif config.model == RetrievalModel.SIGLIP:
            index_file = await index_siglip(dataset.page_ids, dataset.page_files)
            shards = list(_batches(dataset.queries, SEARCH_SHARD_SIZE))
            shard_results = await asyncio.gather(
                *[search_siglip(index_file, shard, config.top_k) for shard in shards]
            )
            results = [r for shard in shard_results for r in shard]

        else:  # RetrievalModel.OCR_BM25
            page_texts = await extract_page_texts(dataset.page_files)
            results = await search_bm25(page_texts, dataset.page_ids, dataset.queries, config.top_k)

        metrics = await evaluate(results, dataset.queries, config.top_k)

    return ExperimentResult(config=config, metrics=metrics)
# {{/docs-fragment run_experiment}}

# {{docs-fragment compare_experiments}}
@driver.task
async def compare_experiments(
    configs: list[ExperimentConfig],
    subset: str = "docvqa",
    max_pages: int = 200,
) -> ComparisonReport:
    """
    Fan out over all experiment configs and return a ranked comparison table.

    The dataset is loaded once and shared across all experiments. Each config
    runs as a concurrent Flyte task via asyncio.gather. Experiments that share
    a model reuse the cached index — you only pay GPU time for new work.

    On completion, generate_report emits an interactive Chart.js HTML report
    visible directly in the Flyte execution detail page.

    The __main__ entry point overrides these defaults to run on
    vidore_v3_finance_en (~2,942 corpus pages, 1,854 queries) with
    max_pages=2000 to exercise the GPU pipeline at scale.
    """
    dataset = await load_vidore_pages(subset=subset, max_pages=max_pages)

    # All experiments launch concurrently. Shared cached outputs (same model,
    # same corpus) are served from cache rather than recomputed.
    experiment_coros = [run_experiment(config=cfg, dataset=dataset) for cfg in configs]
    results: list[ExperimentResult] = list(await asyncio.gather(*experiment_coros))

    report = ComparisonReport(results=results)
    print(report.summary())
    best = report.best_by("recall_at_k")
    print(f"\nBest by Recall@{best.metrics.k}: {best.config.name}")

    # Emit the interactive HTML report in the Flyte UI.
    await generate_report(report)

    return report
# {{/docs-fragment compare_experiments}}

# ─────────────────────────────────────────────────────────────────────────────
# Entry point
# ─────────────────────────────────────────────────────────────────────────────

if __name__ == "__main__":
    flyte.init_from_config()

    # Define the experiment grid. Each ExperimentConfig is one point in the
    # design space. Varying top_k or adding another run of an existing model
    # is one line here, with no task code changes required.
    #
    # ColPali appears twice with different top_k values. The cache ensures
    # index_colpali runs only once and both experiments share that result.
    # {{docs-fragment grid}}
    configs = [
        ExperimentConfig(name="colpali-top5", model=RetrievalModel.COLPALI, top_k=5),
        ExperimentConfig(name="colpali-top10", model=RetrievalModel.COLPALI, top_k=10),
        ExperimentConfig(name="siglip-top5", model=RetrievalModel.SIGLIP, top_k=5),
        ExperimentConfig(name="ocr-bm25-top5", model=RetrievalModel.OCR_BM25, top_k=5),
    ]
    # {{/docs-fragment grid}}

    run = flyte.with_runcontext().run(
        compare_experiments,
        configs=configs,
        subset="vidore_v3_finance_en",
        max_pages=2000,
    )
    print(f"Run URL: {run.url}")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/multimodal-retrieval-evaluation/retrieval_eval.py*
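
The sharded search in `run_experiment` splits the query list into shards of `SEARCH_SHARD_SIZE`, launches one search call per shard with `asyncio.gather`, and flattens the per-shard results. The `_batches` helper it calls comes from the source file; the sketch below assumes it yields consecutive slices of at most `n` items, and uses a hypothetical `fake_search` coroutine in place of `search_colpali`/`search_siglip`, so it runs without Flyte or a GPU.

```python
import asyncio
from collections.abc import Iterator, Sequence
from typing import TypeVar

T = TypeVar("T")


def batches(items: Sequence[T], n: int) -> Iterator[Sequence[T]]:
    """Yield consecutive slices of at most n items (assumed behavior of _batches)."""
    for start in range(0, len(items), n):
        yield items[start : start + n]


async def fake_search(shard: Sequence[str], top_k: int) -> list[str]:
    """Hypothetical stand-in for a search task: one result per query in the shard."""
    await asyncio.sleep(0)  # yield control, as a remote task call would
    return [f"{q}:top{top_k}" for q in shard]


async def sharded_search(queries: list[str], shard_size: int, top_k: int) -> list[str]:
    shards = list(batches(queries, shard_size))
    # All shard calls run concurrently; gather returns results in shard order,
    # so flattening preserves the original query order.
    shard_results = await asyncio.gather(*[fake_search(s, top_k) for s in shards])
    return [r for shard in shard_results for r in shard]


results = asyncio.run(sharded_search([f"q{i}" for i in range(5)], shard_size=2, top_k=5))
print(results)  # ['q0:top5', 'q1:top5', 'q2:top5', 'q3:top5', 'q4:top5']
```

Because `asyncio.gather` preserves argument order, the flattened list lines up with the input queries even though the shards may finish in any order.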

Run the example from the root of the `unionai-examples` repository (see the [example directory](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/multimodal-retrieval-evaluation)):

```
cd v2/tutorials/multimodal-retrieval-evaluation
python retrieval_eval.py
```

When the run completes, open the `generate_report` task in the UI to see the summary cards, the grouped Recall@K / NDCG@K / MRR bar chart, and the ranked results table.
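
Because the pipeline stores one relevant page per query (`relevant_page_id`), the three report metrics reduce to simple per-query formulas. The sketch below recomputes them for a single query outside Flyte. The `dcg` function assumes the common log2 discount (`rel / log2(rank + 1)`); the tutorial's own `_dcg` helper is not shown in this excerpt, so treat that formula as an assumption.

```python
import math


def dcg(rels: list[int]) -> float:
    """DCG with the standard log2 discount (assumed to match the tutorial's _dcg)."""
    return sum(rel / math.log2(rank + 1) for rank, rel in enumerate(rels, start=1))


def query_metrics(ranked: list[str], relevant: str, k: int) -> tuple[float, float, float]:
    """Return (Recall@K, NDCG@K, reciprocal rank) for a query with one relevant page."""
    top = ranked[:k]
    recall = 1.0 if relevant in top else 0.0
    # With a single relevant page, the ideal ranking puts it first: IDCG = dcg([1]) = 1.
    ndcg = dcg([1 if pid == relevant else 0 for pid in top]) / dcg([1])
    # Reciprocal rank of the first hit over the whole ranked list, 0.0 if absent.
    rr = next((1.0 / r for r, pid in enumerate(ranked, start=1) if pid == relevant), 0.0)
    return recall, ndcg, rr


ranked = ["p7", "p3", "p9", "p1", "p4"]
print(query_metrics(ranked, "p9", k=5))  # (1.0, 0.5, 0.3333333333333333)
print(query_metrics(ranked, "p2", k=5))  # (0.0, 0.0, 0.0)
```

Since the search tasks already truncate each result list to `top_k`, the MRR shown in the report is effectively computed over the top K pages as well.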

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/computer-vision/detr-object-detection ===

# RT-DETR object detection

> **📝 Note**
>
> The code for this tutorial is available in the [unionai-examples repository](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/detr_object_detection).

This tutorial fine-tunes [RT-DETRv2](https://huggingface.co/PekingU/rtdetr_v2_r18vd) on a custom COCO-format dataset from HuggingFace. The pipeline downloads the data and splits it into train and validation sets, fine-tunes the detector while streaming live training charts to Flyte reports, evaluates COCO mAP on the validation split, and renders a side-by-side inference demo of ground-truth and predicted bounding boxes.
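
COCO mAP counts a predicted box as a true positive when its intersection-over-union (IoU) with a same-class ground-truth box clears a threshold; the COCO metric averages over thresholds from 0.50 to 0.95. The dataset stores boxes in COCO `[x, y, width, height]` format, so a minimal IoU sketch for that layout looks like this. It is a standalone illustration, not the library code the pipeline uses for evaluation.

```python
def coco_iou(a: list[float], b: list[float]) -> float:
    """IoU of two boxes in COCO [x, y, width, height] format (top-left origin)."""
    ax1, ay1, aw, ah = a
    bx1, by1, bw, bh = b
    # Intersection rectangle, clamped to zero when the boxes do not overlap.
    ix = max(0.0, min(ax1 + aw, bx1 + bw) - max(ax1, bx1))
    iy = max(0.0, min(ay1 + ah, by1 + bh) - max(ay1, by1))
    inter = ix * iy
    union = aw * ah + bw * bh - inter
    return inter / union if union > 0 else 0.0


gt = [10, 10, 100, 50]
pred = [20, 10, 100, 50]  # same box shifted 10 px to the right
print(round(coco_iou(gt, pred), 4))  # 0.8182
```

A 10-pixel shift on a 100-pixel-wide box still clears the 0.50 and 0.75 thresholds but fails the stricter ones near 0.95, which is why COCO mAP rewards tight localization.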

Flyte highlights:

- **Cached dataset preparation** so re-runs skip the HuggingFace download.
- **Live training reports** with loss curves and optional periodic mAP checkpoints.
- **GPU evaluation and demo tasks** that stream annotated images into the UI.
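
Dataset preparation does two pure-Python things before any training happens: it remaps COCO category ids to a contiguous `0..N-1` range, because HF detection models size their classifier head to `len(id2label)` and index classes directly, and it splits by image id so an image's annotations never straddle train and val. The sketch below mirrors that logic from `prepare_data` on a toy COCO dict; the toy categories and annotations are invented for illustration.

```python
import random


def remap_categories(coco: dict) -> dict:
    """Remap category ids (declared or used by annotations) to contiguous 0..N-1."""
    declared = {c["id"] for c in coco["categories"]}
    used = {a["category_id"] for a in coco["annotations"]}
    all_ids = sorted(declared | used)
    remap = {old: new for new, old in enumerate(all_ids)}
    names = {c["id"]: c["name"] for c in coco["categories"]}
    return {
        **coco,
        # Undeclared ids referenced by annotations get a stub name.
        "categories": [{"id": remap[o], "name": names.get(o, f"category_{o}")} for o in all_ids],
        "annotations": [{**a, "category_id": remap[a["category_id"]]} for a in coco["annotations"]],
    }


def split_by_image(coco: dict, val_fraction: float, seed: int) -> tuple[dict, dict]:
    """Shuffle image ids with a fixed seed and keep each image's annotations together."""
    ids = [im["id"] for im in coco["images"]]
    random.Random(seed).shuffle(ids)
    n_val = max(1, int(len(ids) * val_fraction))
    val_ids = set(ids[:n_val])

    def subset(keep) -> dict:
        return {
            "categories": coco["categories"],
            "images": [im for im in coco["images"] if keep(im["id"])],
            "annotations": [a for a in coco["annotations"] if keep(a["image_id"])],
        }

    return subset(lambda i: i not in val_ids), subset(lambda i: i in val_ids)


toy = {
    "categories": [{"id": 1, "name": "mug"}, {"id": 3, "name": "sticker"}],
    "images": [{"id": i} for i in range(10)],
    # Category 5 is used but never declared, like the orphans prepare_data handles.
    "annotations": [{"image_id": i, "category_id": [1, 3, 5][i % 3]} for i in range(10)],
}
coco = remap_categories(toy)
train, val = split_by_image(coco, val_fraction=0.2, seed=42)
print([c["id"] for c in coco["categories"]], len(train["images"]), len(val["images"]))
# [0, 1, 2] 8 2
```

Seeding the shuffle keeps the split stable across runs, which matters when the task output is cached and reused by later training and evaluation steps.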

## Define the task environments

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.4.0",
#    "torch>=2.9.0",
#    "torchvision>=0.24.0",
#    "transformers>=4.49.0",
#    "accelerate>=0.34.0",
#    "huggingface_hub>=0.24.0",
#    "datasets>=3.0.0",
#    "pillow>=10.0.0",
#    "albumentations>=1.4.0",
#    "torchmetrics>=1.4.0",
#    "pycocotools>=2.0.7",
#    "numpy",
# ]
# main = "pipeline"
# params = ""
# ///
import asyncio
import base64
import io
import json
import logging
import os
import random
import shutil
import tempfile

import flyte
import flyte.io
import flyte.report

# {{docs-fragment env}}
main_img = flyte.Image.from_uv_script(__file__, name="detr-object-detection", pre=True)

gpu_env = flyte.TaskEnvironment(
    name="detr-object-detection-gpu",
    image=main_img,
    resources=flyte.Resources(cpu=4, memory="24Gi", gpu=1),
)

cpu_env = flyte.TaskEnvironment(
    name="detr-object-detection-cpu",
    image=main_img,
    resources=flyte.Resources(cpu=2, memory="6Gi"),
    depends_on=[gpu_env],
)
# {{/docs-fragment env}}

logging.basicConfig(level=logging.WARNING, format="%(message)s", force=True)
log = logging.getLogger(__name__)
log.setLevel(logging.INFO)

# ------------------------------------------------------------------
# Report styling — shared CSS for all task reports
# ------------------------------------------------------------------

REPORT_CSS = """
<style>
  .report { font-family: system-ui, -apple-system, sans-serif; max-width: 960px; margin: 0 auto; color: #1a1a2e; }
  .report h2 { color: #16213e; border-bottom: 2px solid #0f3460; padding-bottom: 8px; margin-top: 24px; }
  .report h3 { color: #0f3460; margin-top: 20px; }
  .report .card { background: #f8f9fa; border: 1px solid #dee2e6; border-radius: 8px; padding: 16px; margin: 12px 0; }
  .report .stat-grid { display: grid; grid-template-columns: repeat(auto-fit, minmax(160px, 1fr)); gap: 12px; margin: 12px 0; }
  .report .stat { background: #fff; border: 1px solid #e9ecef; border-radius: 6px; padding: 12px; text-align: center; }
  .report .stat .value { font-size: 1.5em; font-weight: 700; color: #0f3460; }
  .report .stat .label { font-size: 0.85em; color: #6c757d; margin-top: 4px; }
  .report table { border-collapse: collapse; width: 100%; margin: 12px 0; }
  .report th { background: #0f3460; color: #fff; padding: 10px 14px; text-align: left; font-weight: 600; }
  .report td { padding: 8px 14px; border-bottom: 1px solid #dee2e6; }
  .report tr:nth-child(even) { background: #f8f9fa; }
  .report .highlight { color: #0f3460; font-weight: 700; }
  .report .note { background: #fff3cd; border-left: 4px solid #ffc107; padding: 10px 14px; border-radius: 4px; margin: 12px 0; font-size: 0.9em; }
  .report .img-pair { display: flex; gap: 12px; margin: 16px 0; flex-wrap: wrap; }
  .report .img-pair > div { flex: 1; min-width: 300px; }
  .report .img-pair img { width: 100%; border-radius: 6px; border: 1px solid #dee2e6; }
  .report .img-pair .gt-label { color: #5a7db5; font-weight: 600; }
  .report .img-pair .pred-label { color: #06d6a0; font-weight: 600; }
  .report .badge { display: inline-block; padding: 2px 8px; border-radius: 12px; font-size: 0.8em; font-weight: 600; }
  .report .badge-success { background: #d4edda; color: #155724; }
  .report .badge-info { background: #d1ecf1; color: #0c5460; }
  .report .chart-container { background: #fff; border: 1px solid #dee2e6; border-radius: 8px; padding: 16px; margin: 16px 0; }
</style>
"""

def _wrap_report(html: str) -> str:
    """Wrap HTML content with report styling."""
    return f'{REPORT_CSS}<div class="report">{html}</div>'

# ------------------------------------------------------------------
# SVG chart helpers — lightweight charts without matplotlib
# ------------------------------------------------------------------

def _make_line_chart(
    data: list[dict],
    x_key: str,
    y_keys: list[str],
    title: str = "",
    x_label: str = "",
    y_label: str = "",
    colors: list[str] | None = None,
    width: int = 700,
    height: int = 300,
    y_max_cap: float | None = None,
    x_range_override: tuple[float, float] | None = None,
    y_display_names: dict[str, str] | None = None,
) -> str:
    """Generate an SVG line chart from a list of dicts.

    Args:
        data: List of dicts, each with x_key and y_keys values.
        x_key: Key for x-axis values.
        y_keys: Keys for y-axis series to plot.
        title: Chart title.
        x_label: X-axis label.
        y_label: Y-axis label.
        colors: Colors for each series (defaults to a built-in palette).
        width: SVG width in pixels.
        height: SVG height in pixels.
        y_max_cap: If set, cap the y-axis at this value (e.g. 1.0 for mAP).
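        x_range_override: If set, force the x-axis to this (min, max) range.
        y_display_names: Optional mapping from series key to legend label;
            keys without an entry use the key itself.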

    Returns:
        SVG string.
    """

    default_colors = ["#5a7db5", "#0f3460", "#06d6a0", "#ffc107", "#6c757d"]
    colors = colors or default_colors

    # Chart area margins
    ml, mr, mt, mb = 60, 20, 40, 50
    cw = width - ml - mr
    ch = height - mt - mb

    x_vals = [d[x_key] for d in data] if data else []
    if x_range_override:
        x_min, x_max = x_range_override
    elif x_vals:
        x_min, x_max = min(x_vals), max(x_vals)
    else:
        x_min, x_max = 0, 1
    x_range = x_max - x_min or 1

    # Compute y range across all series
    all_y = []
    for key in y_keys:
        all_y.extend(d[key] for d in data if key in d)
    y_min = min(all_y) if all_y else 0
    y_max = max(all_y) if all_y else 1
    y_pad = (y_max - y_min) * 0.1 or 0.1
    y_min_plot = max(0, y_min - y_pad)
    y_max_plot = y_max + y_pad
    if y_max_cap is not None:
        y_max_plot = min(y_max_plot, y_max_cap)
    y_range = y_max_plot - y_min_plot or 1

    def sx(v):
        return ml + (v - x_min) / x_range * cw

    def sy(v):
        return mt + ch - (v - y_min_plot) / y_range * ch

    # Build SVG
    lines = [
        f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {height}" '
        f'style="width:100%;max-width:{width}px;height:auto;">',
        # Background
        f'<rect width="{width}" height="{height}" fill="#fff" rx="6"/>',
    ]

    # Grid lines (5 horizontal)
    for i in range(6):
        y_tick = y_min_plot + y_range * i / 5
        py = sy(y_tick)
        lines.append(
            f'<line x1="{ml}" y1="{py:.1f}" x2="{ml + cw}" y2="{py:.1f}" '
            f'stroke="#e9ecef" stroke-width="1"/>'
        )
        lines.append(
            f'<text x="{ml - 8}" y="{py + 4:.1f}" text-anchor="end" '
            f'font-size="11" fill="#6c757d">{y_tick:.3f}</text>'
        )

    # Axes
    lines.append(
        f'<line x1="{ml}" y1="{mt}" x2="{ml}" y2="{mt + ch}" '
        f'stroke="#adb5bd" stroke-width="1.5"/>'
    )
    lines.append(
        f'<line x1="{ml}" y1="{mt + ch}" x2="{ml + cw}" y2="{mt + ch}" '
        f'stroke="#adb5bd" stroke-width="1.5"/>'
    )

    # X-axis ticks
    if x_vals:
        n_x_ticks = min(len(data), 10)
        step = max(1, len(data) // n_x_ticks)
        for i in range(0, len(data), step):
            px = sx(x_vals[i])
            lines.append(
                f'<text x="{px:.1f}" y="{mt + ch + 20}" text-anchor="middle" '
                f'font-size="11" fill="#6c757d">{x_vals[i]:.0f}</text>'
            )
    else:
        # Empty chart — generate evenly spaced ticks from x range
        for i in range(6):
            x_tick = x_min + x_range * i / 5
            px = sx(x_tick)
            lines.append(
                f'<text x="{px:.1f}" y="{mt + ch + 20}" text-anchor="middle" '
                f'font-size="11" fill="#6c757d">{x_tick:.0f}</text>'
            )

    # Plot each series
    if not data:
        # Empty chart placeholder
        lines.append(
            f'<text x="{ml + cw / 2}" y="{mt + ch / 2}" text-anchor="middle" '
            f'font-size="13" fill="#adb5bd" font-style="italic">Waiting for data...</text>'
        )
    for si, key in enumerate(y_keys):
        color = colors[si % len(colors)]
        points = [(sx(d[x_key]), sy(d[key])) for d in data if key in d]
        if not points:
            continue
        # Draw line if we have 2+ points (dash odd series for visibility)
        if len(points) >= 2:
            path_d = f"M {points[0][0]:.1f},{points[0][1]:.1f}"
            for px, py in points[1:]:
                path_d += f" L {px:.1f},{py:.1f}"
            dash = ' stroke-dasharray="6,3"' if si % 2 == 1 else ""
            lines.append(
                f'<path d="{path_d}" fill="none" stroke="{color}" '
                f'stroke-width="2" stroke-linejoin="round"{dash}/>'
            )
        # Draw point markers only for sparse series (30 points or fewer),
        # which also makes single-point series visible
        if len(points) <= 30:
            for px, py in points:
                lines.append(
                    f'<circle cx="{px:.1f}" cy="{py:.1f}" r="3" fill="{color}"/>'
                )

    # Title
    if title:
        lines.append(
            f'<text x="{width / 2}" y="22" text-anchor="middle" '
            f'font-size="14" font-weight="600" fill="#1a1a2e">{title}</text>'
        )

    # Axis labels
    if x_label:
        lines.append(
            f'<text x="{ml + cw / 2}" y="{height - 6}" text-anchor="middle" '
            f'font-size="12" fill="#6c757d">{x_label}</text>'
        )
    if y_label:
        lines.append(
            f'<text x="14" y="{mt + ch / 2}" text-anchor="middle" '
            f'font-size="12" fill="#6c757d" '
            f'transform="rotate(-90, 14, {mt + ch / 2})">{y_label}</text>'
        )

    # Legend
    names = y_display_names or {}
    if len(y_keys) > 1:
        lx = ml + 10
        for si, key in enumerate(y_keys):
            color = colors[si % len(colors)]
            ly = mt + 14 + si * 18
            lines.append(
                f'<rect x="{lx}" y="{ly - 6}" width="12" height="12" '
                f'rx="2" fill="{color}"/>'
            )
            label = names.get(key, key)
            lines.append(
                f'<text x="{lx + 16}" y="{ly + 4}" font-size="11" '
                f'fill="#1a1a2e">{label}</text>'
            )

    lines.append("</svg>")
    return "\n".join(lines)

def _make_bar_chart(
    labels: list[str],
    series: dict[str, list[float]],
    title: str = "",
    colors: list[str] | None = None,
    width: int = 700,
    height: int = 300,
    y_max_cap: float | None = None,
) -> str:
    """Generate an SVG grouped bar chart.

    Args:
        labels: Category labels for x-axis.
        series: Dict mapping series name to list of values (same length as labels).
        title: Chart title.
        colors: Colors for each series.
        width: SVG width.
        height: SVG height.
        y_max_cap: If set, cap the y-axis at this value (e.g. 1.0 for mAP).

    Returns:
        SVG string.
    """
    if not labels:
        return ""

    default_colors = ["#adb5bd", "#0f3460", "#06d6a0", "#5a7db5"]
    colors = colors or default_colors

    ml, mr, mt, mb = 60, 20, 40, 60
    cw = width - ml - mr
    ch = height - mt - mb

    all_vals = [v for vals in series.values() for v in vals]
    y_max = max(all_vals) if all_vals else 1
    y_max_plot = y_max * 1.15 or 1
    if y_max_cap is not None:
        y_max_plot = min(y_max_plot, y_max_cap) or y_max_cap

    n_groups = len(labels)
    n_series = len(series)
    group_width = cw / n_groups
    bar_width = group_width * 0.7 / max(n_series, 1)
    gap = group_width * 0.15

    def sy(v):
        return mt + ch - (v / y_max_plot) * ch

    lines_svg = [
        f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {height}" '
        f'style="width:100%;max-width:{width}px;height:auto;">',
        f'<rect width="{width}" height="{height}" fill="#fff" rx="6"/>',
    ]

    # Grid lines
    for i in range(6):
        y_tick = y_max_plot * i / 5
        py = sy(y_tick)
        lines_svg.append(
            f'<line x1="{ml}" y1="{py:.1f}" x2="{ml + cw}" y2="{py:.1f}" '
            f'stroke="#e9ecef" stroke-width="1"/>'
        )
        lines_svg.append(
            f'<text x="{ml - 8}" y="{py + 4:.1f}" text-anchor="end" '
            f'font-size="11" fill="#6c757d">{y_tick:.3f}</text>'
        )

    # Bars
    for gi, label in enumerate(labels):
        gx = ml + gi * group_width + gap
        for si, (name, vals) in enumerate(series.items()):
            color = colors[si % len(colors)]
            bx = gx + si * bar_width
            val = vals[gi]
            by = sy(val)
            bh = mt + ch - by
            lines_svg.append(
                f'<rect x="{bx:.1f}" y="{by:.1f}" width="{bar_width - 1:.1f}" '
                f'height="{bh:.1f}" fill="{color}" rx="2"/>'
            )
            # Value label on top of bar
            lines_svg.append(
                f'<text x="{bx + bar_width / 2:.1f}" y="{by - 4:.1f}" '
                f'text-anchor="middle" font-size="10" fill="#1a1a2e">'
                f'{val:.3f}</text>'
            )
        # Group label
        lines_svg.append(
            f'<text x="{gx + n_series * bar_width / 2:.1f}" y="{mt + ch + 18}" '
            f'text-anchor="middle" font-size="11" fill="#6c757d">{label}</text>'
        )

    # Title
    if title:
        lines_svg.append(
            f'<text x="{width / 2}" y="22" text-anchor="middle" '
            f'font-size="14" font-weight="600" fill="#1a1a2e">{title}</text>'
        )

    # Legend
    lx = ml + cw - len(series) * 100
    for si, name in enumerate(series):
        color = colors[si % len(colors)]
        lines_svg.append(
            f'<rect x="{lx + si * 100}" y="{mt + ch + 35}" width="12" '
            f'height="12" rx="2" fill="{color}"/>'
        )
        lines_svg.append(
            f'<text x="{lx + si * 100 + 16}" y="{mt + ch + 46}" font-size="11" '
            f'fill="#1a1a2e">{name}</text>'
        )

    lines_svg.append("</svg>")
    return "\n".join(lines_svg)

# ------------------------------------------------------------------
# Task 1: Prepare dataset — download COCO JSON + images, split train/val
# ------------------------------------------------------------------

@cpu_env.task(cache="auto")
async def prepare_data(
    dataset_repo: str = "sagecodes/union_flyte_swag_object_detection",
    annotations_path: str = "swag/train.json",
    images_subdir: str = "swag/images",
    val_fraction: float = 0.2,
    seed: int = 42,
) -> flyte.io.Dir:
    """Download a COCO-format dataset from HF and split into train/val."""
    from huggingface_hub import snapshot_download

    log.info(f"Downloading dataset: {dataset_repo}")
    local_repo = snapshot_download(
        repo_id=dataset_repo,
        repo_type="dataset",
    )

    ann_file = os.path.join(local_repo, annotations_path)
    img_root = os.path.join(local_repo, images_subdir)

    with open(ann_file) as f:
        coco = json.load(f)

    images = coco["images"]
    annotations = coco["annotations"]
    categories = coco["categories"]

    log.info(
        f"Loaded {len(images)} images, {len(annotations)} annotations, "
        f"{len(categories)} categories"
    )
    log.info(f"Raw category ids: {sorted({c['id'] for c in categories})}")
    log.info(
        f"Raw annotation category_ids (unique): "
        f"{sorted({a['category_id'] for a in annotations})}"
    )

    # Remap category ids to contiguous 0..N-1. This is required because HF
    # object detection models size their classifier head to len(id2label) and
    # treat class labels as direct indices into that head. Any gap or 1-indexed
    # id causes an out-of-bounds IndexKernel error inside the focal-loss scatter.
    #
    # Build the remap from the UNION of ids declared in `categories` and ids
    # actually used in `annotations` — some datasets have orphaned annotations
    # referencing categories that aren't declared (this one does).
    declared_ids = {c["id"] for c in categories}
    used_ids = {a["category_id"] for a in annotations}
    orphans = used_ids - declared_ids
    if orphans:
        log.warning(
            f"Annotations reference undeclared category ids {sorted(orphans)} — "
            f"adding stub categories."
        )

    all_cat_ids = sorted(declared_ids | used_ids)
    id_remap = {old: new for new, old in enumerate(all_cat_ids)}
    existing_names = {c["id"]: c["name"] for c in categories}
    categories = [
        {"id": id_remap[old], "name": existing_names.get(old, f"category_{old}")}
        for old in all_cat_ids
    ]
    annotations = [
        {**a, "category_id": id_remap[a["category_id"]]} for a in annotations
    ]
    log.info(f"Remapped category ids: {id_remap}")
    log.info(f"Final categories: {categories}")

    # Split by image id
    rng = random.Random(seed)
    img_ids = [im["id"] for im in images]
    rng.shuffle(img_ids)
    n_val = max(1, int(len(img_ids) * val_fraction))
    val_ids = set(img_ids[:n_val])
    train_ids = set(img_ids[n_val:])

    def filter_coco(keep_ids: set) -> dict:
        return {
            "info": coco.get("info", {}),
            "categories": categories,
            "images": [im for im in images if im["id"] in keep_ids],
            "annotations": [a for a in annotations if a["image_id"] in keep_ids],
        }

    train_coco = filter_coco(train_ids)
    val_coco = filter_coco(val_ids)

    log.info(
        f"Split: {len(train_coco['images'])} train / {len(val_coco['images'])} val images"
    )

    # Pack output dir: images/ + train.json + val.json
    out_dir = tempfile.mkdtemp(prefix="coco_split_")
    out_img = os.path.join(out_dir, "images")
    shutil.copytree(img_root, out_img)

    with open(os.path.join(out_dir, "train.json"), "w") as f:
        json.dump(train_coco, f)
    with open(os.path.join(out_dir, "val.json"), "w") as f:
        json.dump(val_coco, f)

    return await flyte.io.Dir.from_local(out_dir)

# ------------------------------------------------------------------
# Helpers — torch Dataset wrapping COCO JSON
# ------------------------------------------------------------------

def _build_torch_dataset(coco_path: str, images_root: str, augment: bool):
    """Build a torch Dataset that yields {image, target} for the HF image processor."""
    import albumentations as A
    import numpy as np
    from PIL import Image
    from torch.utils.data import Dataset

    with open(coco_path) as f:
        coco = json.load(f)

    images_by_id = {im["id"]: im for im in coco["images"]}
    anns_by_image: dict[int, list] = {}
    for a in coco["annotations"]:
        anns_by_image.setdefault(a["image_id"], []).append(a)

    image_ids = list(images_by_id.keys())

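    # NOTE: we deliberately don't resize to a fixed input size here; the HF image
    # processor handles resize + pad. These transforms are augmentation only
    # (RandomScale varies the image size but does not fix it).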
    if augment:
        transform = A.Compose(
            [
                A.HorizontalFlip(p=0.5),
                A.VerticalFlip(p=0.1),
                A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.5),
                A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=30, val_shift_limit=20, p=0.4),
                A.Rotate(limit=15, border_mode=0, p=0.4),
                A.RandomScale(scale_limit=0.2, p=0.4),
                A.GaussianBlur(blur_limit=(3, 5), p=0.2),
                A.GaussNoise(p=0.2),
            ],
            bbox_params=A.BboxParams(
                format="coco",
                label_fields=["category"],
                min_area=4,
                min_visibility=0.1,
                clip=True,
            ),
        )
    else:
        transform = A.Compose(
            [A.NoOp()],
            bbox_params=A.BboxParams(format="coco", label_fields=["category"], clip=True),
        )

    class CocoDataset(Dataset):
        def __len__(self) -> int:
            return len(image_ids)

        def __getitem__(self, idx: int):
            img_id = image_ids[idx]
            meta = images_by_id[img_id]
            img_path = os.path.join(images_root, os.path.basename(meta["file_name"]))
            if not os.path.exists(img_path):
                img_path = os.path.join(images_root, meta["file_name"])
            image = np.array(Image.open(img_path).convert("RGB"))

            anns = anns_by_image.get(img_id, [])
            bboxes = [a["bbox"] for a in anns]
            categories = [a["category_id"] for a in anns]

            out = transform(image=image, bboxes=bboxes, category=categories)
            image_t = out["image"]
            bboxes_t = out["bboxes"]
            categories_t = out["category"]

            target_anns = []
            for bb, cat in zip(bboxes_t, categories_t):
                x, y, w, h = bb
                target_anns.append(
                    {
                        "image_id": img_id,
                        "category_id": int(cat),
                        "bbox": [float(x), float(y), float(w), float(h)],
                        "area": float(w * h),
                        "iscrowd": 0,
                    }
                )

            return {
                "image": image_t,
                "target": {"image_id": img_id, "annotations": target_anns},
            }

    return CocoDataset(), coco["categories"]
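
# ------------------------------------------------------------------
# Illustrative sketch (hypothetical helper, not called by any task):
# COCO stores boxes as [x, y, width, height] with (x, y) the top-left
# corner, while MeanAveragePrecision(box_format="xyxy") and the boxes
# returned by post_process_object_detection use
# [x_min, y_min, x_max, y_max]. train(), evaluate() and inference_demo()
# each do this conversion inline; this is the same arithmetic in one
# place, returning plain lists instead of tensors. Wrapping the result in
# torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4) matches the
# inline targets, and reshape keeps box-free images at shape (0, 4).
# ------------------------------------------------------------------

def _example_coco_anns_to_xyxy(anns: list[dict]) -> tuple[list[list[float]], list[int]]:
    """Convert one image's COCO annotations into (xyxy boxes, category ids)."""
    boxes_xyxy: list[list[float]] = []
    labels: list[int] = []
    for a in anns:
        x, y, w, h = a["bbox"]
        # Bottom-right corner = top-left corner + size.
        boxes_xyxy.append([x, y, x + w, y + h])
        labels.append(a["category_id"])
    return boxes_xyxy, labels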

# ------------------------------------------------------------------
# Task 2: Train
# ------------------------------------------------------------------

@gpu_env.task(report=True)
async def train(
    model_name: str,
    data_dir: flyte.io.Dir,
    epochs: int = 30,
    lr: float = 5e-5,
    batch_size: int = 4,
    weight_decay: float = 1e-4,
    eval_every_n_epochs: int | None = None,
) -> flyte.io.Dir:
    """Fine-tune RT-DETR on COCO-format data.

    Other HuggingFace object-detection models should also work if their image
    processor accepts COCO-style annotations.
    """
    import torch
    from transformers import (
        AutoImageProcessor,
        AutoModelForObjectDetection,
        Trainer,
        TrainerCallback,
        TrainingArguments,
    )

    log.info(f"Training: model={model_name}")
    await flyte.report.replace.aio(_wrap_report(
        f"<h2>Loading model...</h2><p>{model_name}</p>"
        f"<p>Preparing dataset and initializing weights...</p>"
    ), do_flush=True)

    # -- Load data --
    data_path = await data_dir.download()
    images_root = os.path.join(data_path, "images")
    train_json = os.path.join(data_path, "train.json")

    with open(train_json) as f:
        categories = json.load(f)["categories"]
    id2label = {c["id"]: c["name"] for c in categories}
    label2id = {v: k for k, v in id2label.items()}

    train_ds, _ = _build_torch_dataset(train_json, images_root, augment=True)
    log.info(f"Train examples: {len(train_ds)} | Categories: {id2label}")

    # -- Optionally load val set for periodic mAP evaluation --
    val_json = os.path.join(data_path, "val.json")
    val_images = None
    val_targets = None
    if eval_every_n_epochs and os.path.exists(val_json):
        import torch as _torch
        from PIL import Image

        with open(val_json) as f:
            val_coco = json.load(f)
        images_by_id = {im["id"]: im for im in val_coco["images"]}
        anns_by_image: dict[int, list] = {}
        for a in val_coco["annotations"]:
            anns_by_image.setdefault(a["image_id"], []).append(a)
        val_images = []
        val_targets = []
        for img_id, meta in images_by_id.items():
            path = os.path.join(images_root, os.path.basename(meta["file_name"]))
            if not os.path.exists(path):
                path = os.path.join(images_root, meta["file_name"])
            val_images.append(Image.open(path).convert("RGB"))
            boxes_xyxy = []
            labels = []
            for a in anns_by_image.get(img_id, []):
                x, y, w, h = a["bbox"]
                boxes_xyxy.append([x, y, x + w, y + h])
                labels.append(a["category_id"])
            val_targets.append({
                "boxes": _torch.tensor(boxes_xyxy, dtype=_torch.float32).reshape(-1, 4),
                "labels": _torch.tensor(labels, dtype=_torch.long),
            })
        log.info(f"Val examples for periodic eval: {len(val_images)}")

    # -- Processor + model --
    processor = AutoImageProcessor.from_pretrained(model_name)
    model = AutoModelForObjectDetection.from_pretrained(
        model_name,
        num_labels=len(id2label),
        id2label=id2label,
        label2id=label2id,
        ignore_mismatched_sizes=True,
    )

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    log.info(
        f"Parameters: {trainable_params:,} / {total_params:,} "
        f"({trainable_params / total_params * 100:.1f}%)"
    )

    # -- Collator — runs the image processor on each batch --
    def collate_fn(batch):
        images = [b["image"] for b in batch]
        targets = [b["target"] for b in batch]
        enc = processor(images=images, annotations=targets, return_tensors="pt")
        return {"pixel_values": enc["pixel_values"], "labels": enc["labels"]}

    # -- Sanity check: peek at one batch and verify class_labels fit --
    sample = collate_fn([train_ds[i] for i in range(min(2, len(train_ds)))])
    all_labels = []
    for lbl in sample["labels"]:
        all_labels.extend(lbl["class_labels"].tolist())
    log.info(
        f"Sanity check — class_labels in first batch: {sorted(set(all_labels))} | "
        f"model num_labels: {model.config.num_labels} | "
        f"id2label: {model.config.id2label}"
    )
    if all_labels and max(all_labels) >= model.config.num_labels:
        raise ValueError(
            f"class_label {max(all_labels)} out of range for num_labels="
            f"{model.config.num_labels}. Check category id remapping in prepare_data."
        )

    # -- Collect training metrics and update the report chart live.
    # trainer.train() runs in a background thread (via asyncio.to_thread),
    # so the asyncio event loop stays free. We use run_coroutine_threadsafe
    # to push report updates from the callback thread onto that loop.
    training_log: list[dict] = []
    eval_log: list[dict] = []  # periodic mAP checkpoints (epoch, map, map_50)
    loop = asyncio.get_running_loop()

    cat_badges = " ".join(
        f'<span class="badge badge-info">{name}</span>'
        for name in id2label.values()
    )

    def _build_training_report(max_steps: int) -> str:
        """Build the live training report HTML from current training_log."""
        stats_html = f"""
        <h2>Training in Progress...</h2>
        <h3>{model_name}</h3>
        <div class="stat-grid">
          <div class="stat"><div class="value">{len(train_ds)}</div><div class="label">Train Examples</div></div>
          <div class="stat"><div class="value">{epochs}</div><div class="label">Epochs</div></div>
          <div class="stat"><div class="value">{lr}</div><div class="label">Learning Rate</div></div>
          <div class="stat"><div class="value">{batch_size}</div><div class="label">Batch Size</div></div>
          <div class="stat"><div class="value">{total_params:,}</div><div class="label">Total Params</div></div>
          <div class="stat"><div class="value">{trainable_params / total_params * 100:.1f}%</div><div class="label">Trainable</div></div>
        </div>
        <p>Categories: {cat_badges}</p>
        """

        charts_html = ""
        if training_log:
            current = training_log[-1]
            progress_pct = current["step"] / max_steps * 100 if max_steps else 0
            charts_html += f"""
            <div class="card">
              <b>Step {current['step']}/{max_steps}</b>
              ({progress_pct:.0f}%) |
              Epoch {current['epoch']:.2f}/{epochs} |
              Loss: <span class="highlight">{current['loss']:.4f}</span>
              <div style="background:#e9ecef;border-radius:4px;height:8px;margin-top:8px;">
                <div style="background:#0f3460;width:{progress_pct:.1f}%;height:100%;border-radius:4px;"></div>
              </div>
            </div>
            """

            loss_chart = _make_line_chart(
                data=training_log,
                x_key="epoch",
                y_keys=["loss"],
                title="Training Loss",
                x_label="Epoch",
                y_label="Loss",
                colors=["#5a7db5"],
            )
            charts_html += f'<div class="chart-container">{loss_chart}</div>'

            # Only draw the mAP chart once at least one periodic eval has run.
            if eval_every_n_epochs and eval_log:
                # Match x-axis to the loss chart: start at 0, end at current epoch
                current_max_epoch = current["epoch"]
                map_chart = _make_line_chart(
                    data=eval_log,
                    x_key="epoch",
                    y_keys=["map", "map_50"],
                    title="Validation mAP (periodic)",
                    x_label="Epoch",
                    y_label="mAP",
                    colors=["#0f3460", "#06d6a0"],
                    y_max_cap=1.0,
                    y_display_names={"map": "mAP (0.50:0.95)", "map_50": "mAP@50"},
                    x_range_override=(0, current_max_epoch),
                )
                charts_html += f'<div class="chart-container">{map_chart}</div>'

            if "lr" in training_log[0]:
                lr_chart = _make_line_chart(
                    data=training_log,
                    x_key="epoch",
                    y_keys=["lr"],
                    title="Learning Rate Schedule",
                    x_label="Epoch",
                    y_label="LR",
                    colors=["#0f3460"],
                )
                charts_html += f'<div class="chart-container">{lr_chart}</div>'

        return _wrap_report(stats_html + charts_html)

    class MetricsCallback(TrainerCallback):
        def __init__(self):
            self._last_eval_epoch = 0

        def on_log(self, args, state, control, logs=None, **kwargs):
            if not logs or "loss" not in logs:
                return
            entry = {
                "step": state.global_step,
                "epoch": round(logs.get("epoch", 0), 2),
                "loss": round(logs["loss"], 4),
            }
            if "learning_rate" in logs:
                entry["lr"] = logs["learning_rate"]
            if "grad_norm" in logs:
                entry["grad_norm"] = round(float(logs["grad_norm"]), 4)
            training_log.append(entry)
            log.info(
                f"step={state.global_step}/{state.max_steps} "
                f"epoch={entry['epoch']:.2f} "
                f"loss={entry['loss']:.4f}"
            )

            # Push a live report update onto the asyncio event loop.
            # do_flush=True dispatches the update to the UI immediately.
            asyncio.run_coroutine_threadsafe(
                flyte.report.replace.aio(
                    _build_training_report(state.max_steps),
                    do_flush=True,
                ),
                loop,
            )

        def on_epoch_end(self, args, state, control, model=None, **kwargs):
            if not eval_every_n_epochs or val_images is None:
                return
            current_epoch = round(state.epoch)
            if current_epoch % eval_every_n_epochs != 0:
                return
            if current_epoch == self._last_eval_epoch:
                return
            self._last_eval_epoch = current_epoch

            log.info(f"Running periodic mAP eval at epoch {current_epoch}...")
            from torchmetrics.detection.mean_ap import MeanAveragePrecision

            device = next(model.parameters()).device
            # _run_inference sets model.eval(); restore train mode after.
            preds = _run_inference(model, processor, val_images, device, threshold=0.3)
            model.train()

            formatted = [
                {"boxes": p["boxes"], "scores": p["scores"], "labels": p["labels"]}
                for p in preds
            ]
            metric = MeanAveragePrecision(box_format="xyxy", iou_type="bbox")
            metric.update(formatted, val_targets)
            result = metric.compute()
            map_val = round(result["map"].item(), 4)
            map_50 = round(result["map_50"].item(), 4)

            eval_log.append({
                "epoch": current_epoch,
                "map": map_val,
                "map_50": map_50,
            })
            log.info(f"Epoch {current_epoch} — mAP: {map_val:.4f}, mAP@50: {map_50:.4f}")

            asyncio.run_coroutine_threadsafe(
                flyte.report.replace.aio(
                    _build_training_report(state.max_steps),
                    do_flush=True,
                ),
                loop,
            )

    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    output_dir = os.path.join(tempfile.mkdtemp(), "checkpoints")
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        learning_rate=lr,
        weight_decay=weight_decay,
        logging_steps=5,
        save_strategy="no",
        bf16=use_bf16,
        fp16=not use_bf16 and torch.cuda.is_available(),
        warmup_ratio=0.1,
        remove_unused_columns=False,
        dataloader_num_workers=2,
        report_to="none",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        data_collator=collate_fn,
        callbacks=[MetricsCallback()],
    )

    log.info("Starting training...")
    # Run the sync HF training loop in a thread so the asyncio event loop
    # stays free for Flyte's syncify bridge.
    await asyncio.to_thread(trainer.train)
    log.info("Training complete.")

    save_dir = os.path.join(tempfile.mkdtemp(), "finetuned_model")
    trainer.save_model(save_dir)
    processor.save_pretrained(save_dir)
    log.info(f"Model saved to {save_dir}")

    # -- Build final training report --
    stats_html = f"""
    <h2>Training Complete</h2>
    <h3>{model_name}</h3>
    <div class="stat-grid">
      <div class="stat"><div class="value">{len(train_ds)}</div><div class="label">Train Examples</div></div>
      <div class="stat"><div class="value">{epochs}</div><div class="label">Epochs</div></div>
      <div class="stat"><div class="value">{lr}</div><div class="label">Learning Rate</div></div>
      <div class="stat"><div class="value">{batch_size}</div><div class="label">Batch Size</div></div>
      <div class="stat"><div class="value">{total_params:,}</div><div class="label">Total Params</div></div>
      <div class="stat"><div class="value">{trainable_params / total_params * 100:.1f}%</div><div class="label">Trainable</div></div>
    </div>
    <p>Categories: {cat_badges}</p>
    """

    charts_html = ""
    if training_log:
        final_loss = training_log[-1]["loss"]
        min_loss = min(d["loss"] for d in training_log)
        initial_loss = training_log[0]["loss"]
        total_steps = training_log[-1]["step"]

        charts_html += f"""
        <div class="card">
          <b>Training Summary:</b>
          Initial loss: {initial_loss:.4f} |
          Final loss: <span class="highlight">{final_loss:.4f}</span> |
          Min loss: {min_loss:.4f} |
          Total steps: {total_steps}
        </div>
        """

        epoch_range = (0, epochs)

        loss_chart = _make_line_chart(
            data=training_log,
            x_key="epoch",
            y_keys=["loss"],
            title="Training Loss",
            x_label="Epoch",
            y_label="Loss",
            colors=["#5a7db5"],
            x_range_override=epoch_range,
        )
        charts_html += f'<div class="chart-container">{loss_chart}</div>'

    if eval_log:
        map_chart = _make_line_chart(
            data=eval_log,
            x_key="epoch",
            y_keys=["map", "map_50"],
            title="Validation mAP (periodic)",
            x_label="Epoch",
            y_label="mAP",
            colors=["#0f3460", "#06d6a0"],
            y_max_cap=1.0,
            x_range_override=(0, epochs),
            y_display_names={"map": "mAP (0.50:0.95)", "map_50": "mAP@50"},
        )
        charts_html += f'<div class="chart-container">{map_chart}</div>'

    if training_log and "lr" in training_log[0]:
        lr_chart = _make_line_chart(
            data=training_log,
            x_key="epoch",
            y_keys=["lr"],
            title="Learning Rate Schedule",
            x_label="Epoch",
            y_label="LR",
            colors=["#0f3460"],
            x_range_override=(0, epochs),
        )
        charts_html += f'<div class="chart-container">{lr_chart}</div>'

    await flyte.report.replace.aio(_wrap_report(stats_html + charts_html), do_flush=True)

    return await flyte.io.Dir.from_local(save_dir)
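
# ------------------------------------------------------------------
# Illustrative sketch (hypothetical, stdlib only, not called by any task):
# the threading pattern train() relies on. A blocking loop, standing in
# for trainer.train(), runs in a worker thread via asyncio.to_thread, and
# each "log step" schedules a coroutine, standing in for
# flyte.report.replace.aio, back onto the event loop with
# asyncio.run_coroutine_threadsafe. Calling blocking_loop() directly on
# the event-loop thread would deadlock on fut.result(), because the
# scheduled coroutines could never run; that is why the blocking work
# has to live in a thread.
# ------------------------------------------------------------------

async def _example_thread_bridge(n_steps: int = 3) -> list[int]:
    """Run a blocking loop in a thread that pushes n_steps updates to the event loop."""
    import asyncio

    received: list[int] = []
    loop = asyncio.get_running_loop()

    async def push_update(step: int) -> None:
        # Executes on the event-loop thread, like a report update would.
        received.append(step)

    def blocking_loop() -> None:
        # Executes in a worker thread; the event loop stays free meanwhile.
        futures = [
            asyncio.run_coroutine_threadsafe(push_update(step), loop)
            for step in range(1, n_steps + 1)
        ]
        # train() fires and forgets; waiting here only makes the demo deterministic.
        for fut in futures:
            fut.result()

    await asyncio.to_thread(blocking_loop)
    return received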

# ------------------------------------------------------------------
# Inference helpers
# ------------------------------------------------------------------

def _run_inference(model, processor, images, device, threshold: float = 0.3):
    """Run object detection on a list of PIL images. Returns list of dicts."""
    import torch

    results = []
    model.eval()
    for img in images:
        inputs = processor(images=img, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        target_size = torch.tensor([img.size[::-1]], device=device)  # (h, w)
        post = processor.post_process_object_detection(
            outputs, target_sizes=target_size, threshold=threshold
        )[0]
        results.append(
            {
                "scores": post["scores"].cpu(),
                "labels": post["labels"].cpu(),
                "boxes": post["boxes"].cpu(),  # xyxy in original image coords
            }
        )
    return results
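
# ------------------------------------------------------------------
# Illustrative sketch (hypothetical helper, not called by any task): the
# gating rule MetricsCallback.on_epoch_end applies before it calls
# _run_inference for periodic mAP. The HF Trainer reports state.epoch as
# a float, so the callback rounds it, evaluates only on multiples of
# eval_every_n_epochs, and skips an epoch it has already scored.
# ------------------------------------------------------------------

def _example_should_eval(
    epoch: float, eval_every_n_epochs: int | None, last_eval_epoch: int
) -> tuple[bool, int]:
    """Return (run_eval, updated_last_eval_epoch) for one on_epoch_end call."""
    if not eval_every_n_epochs:
        return False, last_eval_epoch
    current = round(epoch)
    # Only every n-th epoch, and never the same epoch twice.
    if current % eval_every_n_epochs != 0 or current == last_eval_epoch:
        return False, last_eval_epoch
    return True, current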

def _draw_boxes(image, boxes, labels, scores, id2label, color: str = "lime"):
    """Draw bounding boxes on a PIL image. Returns a new PIL image."""
    from PIL import ImageDraw, ImageFont

    img = image.copy()
    draw = ImageDraw.Draw(img)
    try:
        font = ImageFont.truetype("DejaVuSans-Bold.ttf", size=max(14, img.width // 60))
    except Exception:
        font = ImageFont.load_default()

    for box, label, score in zip(boxes.tolist(), labels.tolist(), scores.tolist()):
        x0, y0, x1, y1 = box
        width = max(2, img.width // 400)
        draw.rectangle([x0, y0, x1, y1], outline=color, width=width)
        name = id2label.get(int(label), str(int(label)))
        caption = f"{name} {score:.2f}"
        text_bg = draw.textbbox((x0, y0), caption, font=font)
        draw.rectangle(text_bg, fill=color)
        draw.text((x0, y0), caption, fill="black", font=font)
    return img

def _img_to_data_uri(img, max_dim: int = 800) -> str:
    """PIL image → base64 data URI, downscaled for the report."""
    w, h = img.size
    if max(w, h) > max_dim:
        scale = max_dim / max(w, h)
        img = img.resize((int(w * scale), int(h * scale)))
    buf = io.BytesIO()
    img.save(buf, format="JPEG", quality=85)
    return "data:image/jpeg;base64," + base64.b64encode(buf.getvalue()).decode()
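
# ------------------------------------------------------------------
# Illustrative sketch (hypothetical helper, not called by any task): the
# downscaling arithmetic in _img_to_data_uri. The longer side is scaled
# to max_dim and the shorter side by the same factor, preserving aspect
# ratio; images already within max_dim keep their size.
# ------------------------------------------------------------------

def _example_report_size(w: int, h: int, max_dim: int = 800) -> tuple[int, int]:
    """Return the (width, height) _img_to_data_uri would resize an image to."""
    if max(w, h) <= max_dim:
        return w, h
    scale = max_dim / max(w, h)
    # int() truncates, matching img.resize((int(w * scale), int(h * scale))).
    return int(w * scale), int(h * scale)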

# ------------------------------------------------------------------
# Task 3: Evaluate — COCO mAP on fine-tuned model
# ------------------------------------------------------------------

@gpu_env.task(report=True)
async def evaluate(
    finetuned_dir: flyte.io.Dir,
    data_dir: flyte.io.Dir,
    threshold: float = 0.5,
) -> str:
    """Compute COCO mAP for the fine-tuned model on the val split."""
    import torch
    from PIL import Image
    from torchmetrics.detection.mean_ap import MeanAveragePrecision
    from transformers import AutoImageProcessor, AutoModelForObjectDetection

    log.info("Starting evaluation...")
    await flyte.report.replace.aio(_wrap_report(
        "<h2>Evaluation</h2><p>Loading val split and scoring model...</p>"
    ), do_flush=True)

    data_path = await data_dir.download()
    images_root = os.path.join(data_path, "images")
    val_json = os.path.join(data_path, "val.json")

    with open(val_json) as f:
        val_coco = json.load(f)

    images_by_id = {im["id"]: im for im in val_coco["images"]}
    anns_by_image: dict[int, list] = {}
    for a in val_coco["annotations"]:
        anns_by_image.setdefault(a["image_id"], []).append(a)

    pil_images = []
    targets = []
    for img_id, meta in images_by_id.items():
        path = os.path.join(images_root, os.path.basename(meta["file_name"]))
        if not os.path.exists(path):
            path = os.path.join(images_root, meta["file_name"])
        pil_images.append(Image.open(path).convert("RGB"))
        boxes_xyxy = []
        labels = []
        for a in anns_by_image.get(img_id, []):
            x, y, w, h = a["bbox"]
            boxes_xyxy.append([x, y, x + w, y + h])
            labels.append(a["category_id"])
        targets.append(
            {
                "boxes": torch.tensor(boxes_xyxy, dtype=torch.float32).reshape(-1, 4),
                "labels": torch.tensor(labels, dtype=torch.long),
            }
        )

    device = "cuda" if torch.cuda.is_available() else "cpu"

    ft_path = await finetuned_dir.download()
    log.info(f"Scoring fine-tuned model: {ft_path}")
    processor = AutoImageProcessor.from_pretrained(ft_path)
    model = AutoModelForObjectDetection.from_pretrained(ft_path).to(device)
    preds = _run_inference(model, processor, pil_images, device, threshold=threshold)

    formatted_preds = [
        {"boxes": p["boxes"], "scores": p["scores"], "labels": p["labels"]}
        for p in preds
    ]

    metric = MeanAveragePrecision(box_format="xyxy", iou_type="bbox")
    metric.update(formatted_preds, targets)

    def to_python(v):
        if hasattr(v, "numel"):
            return v.item() if v.numel() == 1 else v.tolist()
        return v

    ft_metrics = {k: to_python(v) for k, v in metric.compute().items()}
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    log.info(f"Fine-tuned mAP: {ft_metrics.get('map', 0):.3f}")

    metric_keys = ["map", "map_50", "map_75", "mar_10"]
    metric_display = {
        "map": "mAP",
        "map_50": "mAP@50",
        "map_75": "mAP@75",
        "mar_10": "mAR@10",
    }

    rows = []
    for key in metric_keys:
        ft_val = ft_metrics.get(key, 0)
        rows.append(
            f"<tr><td><b>{metric_display.get(key, key)}</b></td>"
            f"<td class='highlight'>{ft_val:.3f}</td></tr>"
        )
    table = (
        "<table><tr><th>Metric</th><th>Score</th></tr>"
        + "".join(rows)
        + "</table>"
    )

    bar_chart = _make_bar_chart(
        labels=[metric_display.get(k, k) for k in metric_keys],
        series={"Fine-tuned": [ft_metrics.get(k, 0) for k in metric_keys]},
        title="COCO Evaluation Metrics",
        colors=["#0f3460"],
        y_max_cap=1.0,
    )

    ft_map = ft_metrics.get("map", 0)
    ft_map50 = ft_metrics.get("map_50", 0)

    eval_html = f"""
    <h2>Evaluation — COCO mAP</h2>
    <div class="stat-grid">
      <div class="stat"><div class="value">{len(pil_images)}</div><div class="label">Val Images</div></div>
      <div class="stat"><div class="value">{threshold}</div><div class="label">Threshold</div></div>
      <div class="stat"><div class="value highlight">{ft_map:.3f}</div><div class="label">mAP</div></div>
      <div class="stat"><div class="value highlight">{ft_map50:.3f}</div><div class="label">mAP@50</div></div>
    </div>
    <div class="chart-container">{bar_chart}</div>
    {table}
    <div class="note">
      <b>mAP</b> (mean Average Precision) summarizes the detector's
      precision/recall trade-off: precision asks whether predicted boxes are
      correct, recall asks whether every object was found. The headline mAP
      averages over IoU thresholds from 0.50 to 0.95; the @50 and @75 variants
      count a prediction as correct only if its IoU (intersection over union)
      with a ground-truth box is at least 0.50 or 0.75. <b>mAR@10</b> (mean
      Average Recall) measures how many ground-truth objects the model finds
      when limited to its top 10 detections per image.
    </div>
    """

    await flyte.report.replace.aio(_wrap_report(eval_html), do_flush=True)

    return json.dumps(
        {
            "finetuned": {k: round(v, 4) for k, v in ft_metrics.items() if isinstance(v, (int, float))},
            "num_val_images": len(pil_images),
        }
    )
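
# ------------------------------------------------------------------
# Illustrative sketch (hypothetical helpers, not called by any task and
# not torchmetrics' implementation): how the IoU thresholds behind mAP@50
# and mAP@75 enter the score. For one class, detections are sorted by
# confidence, greedily matched to the unmatched ground-truth box with the
# highest IoU, and counted as true positives only if that IoU reaches
# iou_thr. AP is the area under the resulting precision/recall curve.
# This sketch uses all-point interpolation; COCO and torchmetrics use
# 101-point interpolation and average over classes (and, for plain mAP,
# over IoU thresholds 0.50:0.05:0.95), so numbers can differ slightly.
# ------------------------------------------------------------------

def _example_box_iou(a: list[float], b: list[float]) -> float:
    """Intersection over union of two [x_min, y_min, x_max, y_max] boxes."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0

def _example_average_precision(
    preds: list[tuple[float, list[float]]],
    gts: list[list[float]],
    iou_thr: float = 0.5,
) -> float:
    """AP for one class: preds are (score, xyxy box) pairs, gts are xyxy boxes."""
    if not gts:
        return 0.0  # convention for this sketch; AP is undefined without ground truth
    matched = [False] * len(gts)
    precisions: list[float] = []
    recalls: list[float] = []
    tp = 0
    ranked = sorted(preds, key=lambda p: p[0], reverse=True)
    for rank, (_, box) in enumerate(ranked, start=1):
        # Greedy match to the best still-unmatched ground-truth box.
        best_iou, best_j = 0.0, -1
        for j, gt in enumerate(gts):
            iou = _example_box_iou(box, gt)
            if not matched[j] and iou > best_iou:
                best_iou, best_j = iou, j
        if best_j >= 0 and best_iou >= iou_thr:
            matched[best_j] = True
            tp += 1
        precisions.append(tp / rank)
        recalls.append(tp / len(gts))
    # Make precision non-increasing from right to left (interpolated precision).
    for i in range(len(precisions) - 2, -1, -1):
        precisions[i] = max(precisions[i], precisions[i + 1])
    # Sum interpolated precision over each recall increment.
    ap, prev_recall = 0.0, 0.0
    for p, r in zip(precisions, recalls):
        ap += (r - prev_recall) * p
        prev_recall = r
    return ap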

# ------------------------------------------------------------------
# Task 4: Inference demo — render bboxes on val images
# ------------------------------------------------------------------

@gpu_env.task(report=True)
async def inference_demo(
    finetuned_dir: flyte.io.Dir,
    data_dir: flyte.io.Dir,
    threshold: float = 0.5,
    max_images: int = 8,
    metrics_json: str = "{}",
) -> str:
    """Run the fine-tuned model on val images, render bboxes, embed in the report."""
    import torch
    from PIL import Image
    from torchmetrics.detection.mean_ap import MeanAveragePrecision
    from transformers import AutoImageProcessor, AutoModelForObjectDetection

    data_path = await data_dir.download()
    images_root = os.path.join(data_path, "images")
    val_json = os.path.join(data_path, "val.json")

    with open(val_json) as f:
        val_coco = json.load(f)

    id2label = {c["id"]: c["name"] for c in val_coco["categories"]}
    metas = val_coco["images"][:max_images]
    anns_by_image: dict[int, list] = {}
    for a in val_coco["annotations"]:
        anns_by_image.setdefault(a["image_id"], []).append(a)

    pil_images = []
    gt_per_image = []
    for meta in metas:
        path = os.path.join(images_root, os.path.basename(meta["file_name"]))
        if not os.path.exists(path):
            path = os.path.join(images_root, meta["file_name"])
        pil_images.append(Image.open(path).convert("RGB"))

        boxes_xyxy = []
        labels = []
        for a in anns_by_image.get(meta["id"], []):
            x, y, w, h = a["bbox"]
            boxes_xyxy.append([x, y, x + w, y + h])
            labels.append(a["category_id"])
        gt_per_image.append(
            {
                "boxes": torch.tensor(boxes_xyxy, dtype=torch.float32).reshape(-1, 4),
                "labels": torch.tensor(labels, dtype=torch.long),
                "scores": torch.ones(len(labels)),
            }
        )

    ft_path = await finetuned_dir.download()
    processor = AutoImageProcessor.from_pretrained(ft_path)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModelForObjectDetection.from_pretrained(ft_path).to(device)

    preds = _run_inference(model, processor, pil_images, device, threshold=threshold)

    html_blocks = []
    total_gt = 0
    total_pred = 0
    for i, (img, pred, gt) in enumerate(zip(pil_images, preds, gt_per_image)):
        n_gt = len(gt["labels"])
        n_pred = len(pred["labels"])
        total_gt += n_gt
        total_pred += n_pred

        # Per-image mAP
        metric = MeanAveragePrecision(box_format="xyxy", iou_type="bbox")
        metric.update(
            [{"boxes": pred["boxes"], "scores": pred["scores"], "labels": pred["labels"]}],
            [{"boxes": gt["boxes"], "labels": gt["labels"]}],
        )
        img_metrics = metric.compute()
        img_map = img_metrics["map"].item()
        img_map_badge = (
            f'<span class="badge badge-success">mAP {img_map:.2f}</span>'
            if img_map >= 0.5
            else f'<span class="badge badge-info">mAP {img_map:.2f}</span>'
        )

        pred_img = _draw_boxes(
            img, pred["boxes"], pred["labels"], pred["scores"],
            id2label, color="lime",
        )
        gt_img = _draw_boxes(
            img, gt["boxes"], gt["labels"], gt["scores"],
            id2label, color="dodgerblue",
        )
        html_blocks.append(f"""
        <div class="card">
          <b>Image {i + 1}</b> {img_map_badge}
          <div class="img-pair">
            <div>
              <p><span class="gt-label">Ground Truth</span>
                 <span class="badge badge-info">{n_gt} boxes</span></p>
              <img src="{_img_to_data_uri(gt_img)}" />
            </div>
            <div>
              <p><span class="pred-label">Predictions</span>
                 <span class="badge badge-success">{n_pred} boxes</span>
                 (threshold={threshold})</p>
              <img src="{_img_to_data_uri(pred_img)}" />
            </div>
          </div>
        </div>""")

    # Parse metrics if provided (from evaluate task)
    metrics = json.loads(metrics_json)
    ft_metrics = metrics.get("finetuned", {})
    ft_map = ft_metrics.get("map", None)
    ft_map50 = ft_metrics.get("map_50", None)

    metrics_stats = ""
    if ft_map is not None:
        metrics_stats = f"""
        <div class="stat"><div class="value highlight">{ft_map:.3f}</div><div class="label">mAP</div></div>
        <div class="stat"><div class="value highlight">{ft_map50:.3f}</div><div class="label">mAP@50</div></div>
        """

    demo_html = f"""
    <h2>Inference Demo</h2>
    <h3>Fine-tuned RT-DETR on validation images</h3>
    <div class="stat-grid">
      {metrics_stats}
      <div class="stat"><div class="value">{len(pil_images)}</div><div class="label">Images Shown</div></div>
      <div class="stat"><div class="value">{total_gt}</div><div class="label">Ground Truth Boxes</div></div>
      <div class="stat"><div class="value">{total_pred}</div><div class="label">Predicted Boxes</div></div>
      <div class="stat"><div class="value">{threshold}</div><div class="label">Confidence Threshold</div></div>
    </div>
    <p><span class="gt-label">Blue = ground truth</span> |
       <span class="pred-label">Green = predictions</span></p>
    {"".join(html_blocks)}
    """

    await flyte.report.replace.aio(_wrap_report(demo_html), do_flush=True)

    return json.dumps(
        {
            "num_images": len(pil_images),
            "predictions_per_image": [len(p["labels"]) for p in preds],
        }
    )

# ------------------------------------------------------------------
# Pipeline
# ------------------------------------------------------------------

# {{docs-fragment pipeline}}
@cpu_env.task(report=True)
async def pipeline(
    model_name: str = "PekingU/rtdetr_v2_r18vd",
    dataset_repo: str = "sagecodes/union_flyte_swag_object_detection",
    annotations_path: str = "swag/train.json",
    images_subdir: str = "swag/images",
    epochs: int = 30,
    lr: float = 5e-5,
    batch_size: int = 4,
    val_fraction: float = 0.2,
    threshold: float = 0.5,
    demo_images: int = 8,
    eval_every_n_epochs: int | None = None,
) -> tuple[flyte.io.Dir, str]:
    """
    End-to-end RT-DETRv2 fine-tuning pipeline.

    Returns the fine-tuned model directory and a JSON summary.

    1. Download a COCO-format dataset from Hugging Face and split it into train/val
    2. Fine-tune RT-DETRv2 on the train split
    3. Evaluate: COCO mAP comparison (base vs fine-tuned)
    4. Inference demo: render bounding boxes on val images
    """
    log.info(f"Pipeline: {model_name} | dataset={dataset_repo}")

    def _pipeline_progress(step: int, label: str) -> str:
        steps = ["Preparing Data", "Fine-tuning", "Evaluating", "Inference Demo"]
        dots = ""
        for i, s in enumerate(steps):
            if i + 1 < step:
                icon = '<span style="color:#06d6a0;">&#10003;</span>'
            elif i + 1 == step:
                icon = '<span style="color:#e94560;">&#9679;</span>'
            else:
                icon = '<span style="color:#adb5bd;">&#9675;</span>'
            dots += f"<span style='margin:0 8px;'>{icon} {s}</span>"
        return f"""
        <h2>RT-DETRv2 Object Detection Pipeline</h2>
        <p><b>Model:</b> {model_name} | <b>Dataset:</b> {dataset_repo}</p>
        <div class="card" style="text-align:center;">{dots}</div>
        <p>{label}</p>
        """

    await flyte.report.replace.aio(
        _wrap_report(_pipeline_progress(1, "Downloading and splitting dataset...")),
        do_flush=True,
    )

    data_dir = await prepare_data(
        dataset_repo=dataset_repo,
        annotations_path=annotations_path,
        images_subdir=images_subdir,
        val_fraction=val_fraction,
    )

    await flyte.report.replace.aio(
        _wrap_report(_pipeline_progress(2, "Fine-tuning model...")),
        do_flush=True,
    )

    finetuned_dir = await train(
        model_name, data_dir, epochs, lr, batch_size,
        eval_every_n_epochs=eval_every_n_epochs,
    )

    await flyte.report.replace.aio(
        _wrap_report(_pipeline_progress(3, "Running COCO mAP evaluation...")),
        do_flush=True,
    )

    metrics_json = await evaluate(finetuned_dir, data_dir, threshold)
    metrics = json.loads(metrics_json)

    await flyte.report.replace.aio(
        _wrap_report(_pipeline_progress(4, "Rendering bounding box demo...")),
        do_flush=True,
    )

    demo_json = await inference_demo(
        finetuned_dir, data_dir, threshold, demo_images,
        metrics_json=metrics_json,
    )

    ft_map = metrics["finetuned"].get("map", 0)
    ft_map50 = metrics["finetuned"].get("map_50", 0)

    final_html = f"""
    <h2>Pipeline Complete</h2>
    <h3>{model_name}</h3>
    <div class="stat-grid">
      <div class="stat"><div class="value">{metrics['num_val_images']}</div><div class="label">Val Images</div></div>
      <div class="stat"><div class="value highlight">{ft_map:.3f}</div><div class="label">mAP</div></div>
      <div class="stat"><div class="value highlight">{ft_map50:.3f}</div><div class="label">mAP@50</div></div>
    </div>
    <div class="card">
      <b>Configuration:</b> {epochs} epochs | LR {lr} | Batch size {batch_size} |
      Val fraction {val_fraction} | Threshold {threshold}
    </div>
    """

    await flyte.report.replace.aio(_wrap_report(final_html), do_flush=True)

    log.info(f"Pipeline complete. Fine-tuned mAP: {ft_map:.3f}")
    return finetuned_dir, json.dumps({"metrics": metrics, "demo": json.loads(demo_json)})

# {{/docs-fragment pipeline}}

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(pipeline)
    print(run.url)
    run.wait()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/detr_object_detection/detr_object_detection.py*

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.4.0",
#    "torch>=2.9.0",
#    "transformers>=4.49.0",
#    "albumentations>=1.4.0",
#    "torchmetrics>=1.4.0",
#    ...
# ]
# ///
```

## Orchestrate the pipeline

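The `pipeline` task runs four stages in order:

- It prepares the data (download and train/val split).
- It fine-tunes RT-DETRv2.
- It evaluates COCO mAP on the validation split.
- It renders an inference demo with predicted bounding boxes.

Before each stage it replaces the task report with a progress view, so you can follow the run while it executes.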
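The control flow of `pipeline` can be sketched in plain `asyncio`. This is a minimal sketch under these assumptions:

- The stubs `prepare_data`, `train`, `evaluate`, and `inference_demo` are hypothetical stand-ins that return placeholder values.
- `progress` mimics `_pipeline_progress` by marking each stage as done, current, or pending.

It is not Flyte's API. In the real task, each awaited call is a Flyte task; for example, `train` runs in `gpu_env`.

```
import asyncio

STAGES = ["Preparing Data", "Fine-tuning", "Evaluating", "Inference Demo"]


def progress(step: int) -> list[str]:
    """Mark each stage done / current / pending for a 1-indexed step."""
    marks = []
    for i, name in enumerate(STAGES, start=1):
        state = "done" if i < step else "current" if i == step else "pending"
        marks.append(f"{state}:{name}")
    return marks


# Hypothetical stubs standing in for the real Flyte tasks.
async def prepare_data() -> str:
    return "data_dir"


async def train(data_dir: str) -> str:
    return f"model<{data_dir}>"


async def evaluate(model_dir: str, data_dir: str) -> dict:
    return {"finetuned": {"map": 0.0}}


async def inference_demo(model_dir: str, data_dir: str, metrics: dict) -> dict:
    return {"num_images": 0}


async def pipeline() -> tuple[str, dict, list[list[str]]]:
    reports: list[list[str]] = []  # one snapshot per report update

    reports.append(progress(1))
    data_dir = await prepare_data()

    reports.append(progress(2))
    model_dir = await train(data_dir)  # needs prepare_data's output

    reports.append(progress(3))
    metrics = await evaluate(model_dir, data_dir)

    reports.append(progress(4))
    demo = await inference_demo(model_dir, data_dir, metrics)

    return model_dir, {"metrics": metrics, "demo": demo}, reports


if __name__ == "__main__":
    model_dir, summary, reports = asyncio.run(pipeline())
    for snapshot in reports:
        print(snapshot)
```

Each stage consumes the previous stage's output, so the stages run strictly in sequence. Recording a progress snapshot before each `await` is what lets the report show which stage is currently running.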

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.4.0",
#    "torch>=2.9.0",
#    "torchvision>=0.24.0",
#    "transformers>=4.49.0",
#    "accelerate>=0.34.0",
#    "huggingface_hub>=0.24.0",
#    "datasets>=3.0.0",
#    "pillow>=10.0.0",
#    "albumentations>=1.4.0",
#    "torchmetrics>=1.4.0",
#    "pycocotools>=2.0.7",
#    "numpy",
# ]
# main = "pipeline"
# params = ""
# ///
import asyncio
import base64
import io
import json
import logging
import os
import random
import shutil
import tempfile

import flyte
import flyte.io
import flyte.report

# {{docs-fragment env}}
main_img = flyte.Image.from_uv_script(__file__, name="detr-object-detection", pre=True)

gpu_env = flyte.TaskEnvironment(
    name="detr-object-detection-gpu",
    image=main_img,
    resources=flyte.Resources(cpu=4, memory="24Gi", gpu=1),
)

cpu_env = flyte.TaskEnvironment(
    name="detr-object-detection-cpu",
    image=main_img,
    resources=flyte.Resources(cpu=2, memory="6Gi"),
    depends_on=[gpu_env],
)
# {{/docs-fragment env}}

logging.basicConfig(level=logging.WARNING, format="%(message)s", force=True)
log = logging.getLogger(__name__)
log.setLevel(logging.INFO)

# ------------------------------------------------------------------
# Report styling — shared CSS for all task reports
# ------------------------------------------------------------------

REPORT_CSS = """
<style>
  .report { font-family: system-ui, -apple-system, sans-serif; max-width: 960px; margin: 0 auto; color: #1a1a2e; }
  .report h2 { color: #16213e; border-bottom: 2px solid #0f3460; padding-bottom: 8px; margin-top: 24px; }
  .report h3 { color: #0f3460; margin-top: 20px; }
  .report .card { background: #f8f9fa; border: 1px solid #dee2e6; border-radius: 8px; padding: 16px; margin: 12px 0; }
  .report .stat-grid { display: grid; grid-template-columns: repeat(auto-fit, minmax(160px, 1fr)); gap: 12px; margin: 12px 0; }
  .report .stat { background: #fff; border: 1px solid #e9ecef; border-radius: 6px; padding: 12px; text-align: center; }
  .report .stat .value { font-size: 1.5em; font-weight: 700; color: #0f3460; }
  .report .stat .label { font-size: 0.85em; color: #6c757d; margin-top: 4px; }
  .report table { border-collapse: collapse; width: 100%; margin: 12px 0; }
  .report th { background: #0f3460; color: #fff; padding: 10px 14px; text-align: left; font-weight: 600; }
  .report td { padding: 8px 14px; border-bottom: 1px solid #dee2e6; }
  .report tr:nth-child(even) { background: #f8f9fa; }
  .report .highlight { color: #0f3460; font-weight: 700; }
  .report .note { background: #fff3cd; border-left: 4px solid #ffc107; padding: 10px 14px; border-radius: 4px; margin: 12px 0; font-size: 0.9em; }
  .report .img-pair { display: flex; gap: 12px; margin: 16px 0; flex-wrap: wrap; }
  .report .img-pair > div { flex: 1; min-width: 300px; }
  .report .img-pair img { width: 100%; border-radius: 6px; border: 1px solid #dee2e6; }
  .report .img-pair .gt-label { color: #5a7db5; font-weight: 600; }
  .report .img-pair .pred-label { color: #06d6a0; font-weight: 600; }
  .report .badge { display: inline-block; padding: 2px 8px; border-radius: 12px; font-size: 0.8em; font-weight: 600; }
  .report .badge-success { background: #d4edda; color: #155724; }
  .report .badge-info { background: #d1ecf1; color: #0c5460; }
  .report .chart-container { background: #fff; border: 1px solid #dee2e6; border-radius: 8px; padding: 16px; margin: 16px 0; }
</style>
"""

def _wrap_report(html: str) -> str:
    """Wrap HTML content with report styling."""
    return f'{REPORT_CSS}<div class="report">{html}</div>'

# ------------------------------------------------------------------
# SVG chart helpers — lightweight charts without matplotlib
# ------------------------------------------------------------------

def _make_line_chart(
    data: list[dict],
    x_key: str,
    y_keys: list[str],
    title: str = "",
    x_label: str = "",
    y_label: str = "",
    colors: list[str] | None = None,
    width: int = 700,
    height: int = 300,
    y_max_cap: float | None = None,
    x_range_override: tuple[float, float] | None = None,
    y_display_names: dict[str, str] | None = None,
) -> str:
    """Generate an SVG line chart from a list of dicts.

    Args:
        data: List of dicts, each with x_key and y_keys values.
        x_key: Key for x-axis values.
        y_keys: Keys for y-axis series to plot.
        title: Chart title.
        x_label: X-axis label.
        y_label: Y-axis label.
        colors: Colors for each series (defaults to a built-in palette).
        width: SVG width in pixels.
        height: SVG height in pixels.
        y_max_cap: If set, cap the y-axis at this value (e.g. 1.0 for mAP).
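        x_range_override: If set, force the x-axis to this (min, max) range.
        y_display_names: Optional mapping from y_key to its legend label, shown
            when plotting more than one series (defaults to the key itself).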

    Returns:
        SVG string.
    """

    default_colors = ["#5a7db5", "#0f3460", "#06d6a0", "#ffc107", "#6c757d"]
    colors = colors or default_colors

    # Chart area margins
    ml, mr, mt, mb = 60, 20, 40, 50
    cw = width - ml - mr
    ch = height - mt - mb

    x_vals = [d[x_key] for d in data] if data else []
    if x_range_override:
        x_min, x_max = x_range_override
    elif x_vals:
        x_min, x_max = min(x_vals), max(x_vals)
    else:
        x_min, x_max = 0, 1
    x_range = x_max - x_min or 1

    # Compute y range across all series
    all_y = []
    for key in y_keys:
        all_y.extend(d[key] for d in data if key in d)
    y_min = min(all_y) if all_y else 0
    y_max = max(all_y) if all_y else 1
    y_pad = (y_max - y_min) * 0.1 or 0.1
    y_min_plot = max(0, y_min - y_pad)
    y_max_plot = y_max + y_pad
    if y_max_cap is not None:
        y_max_plot = min(y_max_plot, y_max_cap)
    y_range = y_max_plot - y_min_plot or 1

    def sx(v):
        return ml + (v - x_min) / x_range * cw

    def sy(v):
        return mt + ch - (v - y_min_plot) / y_range * ch
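
    # Worked example of the mapping (default width=700, height=300):
    #   cw = 700 - 60 - 20 = 620 px and ch = 300 - 40 - 50 = 210 px,
    #   so sx(x_min) = 60 and sx(x_max) = 680 (left to right), while
    #   sy(y_min_plot) = 40 + 210 = 250 and sy(y_max_plot) = 40 (bottom to top).
    # SVG's y axis grows downward, which is why sy() subtracts from mt + ch.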

    # Build SVG
    lines = [
        f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {height}" '
        f'style="width:100%;max-width:{width}px;height:auto;">',
        # Background
        f'<rect width="{width}" height="{height}" fill="#fff" rx="6"/>',
    ]

    # Grid lines (5 horizontal)
    for i in range(6):
        y_tick = y_min_plot + y_range * i / 5
        py = sy(y_tick)
        lines.append(
            f'<line x1="{ml}" y1="{py:.1f}" x2="{ml + cw}" y2="{py:.1f}" '
            f'stroke="#e9ecef" stroke-width="1"/>'
        )
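        # ".3g" keeps tiny values such as learning rates (e.g. 5e-05) readable;
        # a fixed ".3f" would print them all as 0.000.
        lines.append(
            f'<text x="{ml - 8}" y="{py + 4:.1f}" text-anchor="end" '
            f'font-size="11" fill="#6c757d">{y_tick:.3g}</text>'
        )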

    # Axes
    lines.append(
        f'<line x1="{ml}" y1="{mt}" x2="{ml}" y2="{mt + ch}" '
        f'stroke="#adb5bd" stroke-width="1.5"/>'
    )
    lines.append(
        f'<line x1="{ml}" y1="{mt + ch}" x2="{ml + cw}" y2="{mt + ch}" '
        f'stroke="#adb5bd" stroke-width="1.5"/>'
    )

    # X-axis ticks
    if x_vals:
        n_x_ticks = min(len(data), 10)
        step = max(1, len(data) // n_x_ticks)
        for i in range(0, len(data), step):
            px = sx(x_vals[i])
            lines.append(
                f'<text x="{px:.1f}" y="{mt + ch + 20}" text-anchor="middle" '
                f'font-size="11" fill="#6c757d">{x_vals[i]:.0f}</text>'
            )
    else:
        # Empty chart — generate evenly spaced ticks from x range
        for i in range(6):
            x_tick = x_min + x_range * i / 5
            px = sx(x_tick)
            lines.append(
                f'<text x="{px:.1f}" y="{mt + ch + 20}" text-anchor="middle" '
                f'font-size="11" fill="#6c757d">{x_tick:.0f}</text>'
            )

    # Plot each series
    if not data:
        # Empty chart placeholder
        lines.append(
            f'<text x="{ml + cw / 2}" y="{mt + ch / 2}" text-anchor="middle" '
            f'font-size="13" fill="#adb5bd" font-style="italic">Waiting for data...</text>'
        )
    for si, key in enumerate(y_keys):
        color = colors[si % len(colors)]
        points = [(sx(d[x_key]), sy(d[key])) for d in data if key in d]
        if not points:
            continue
        # Draw line if we have 2+ points (dash odd series for visibility)
        if len(points) >= 2:
            path_d = f"M {points[0][0]:.1f},{points[0][1]:.1f}"
            for px, py in points[1:]:
                path_d += f" L {px:.1f},{py:.1f}"
            dash = ' stroke-dasharray="6,3"' if si % 2 == 1 else ""
            lines.append(
                f'<path d="{path_d}" fill="none" stroke="{color}" '
                f'stroke-width="2" stroke-linejoin="round"{dash}/>'
            )
        # Draw point markers only for sparse series (30 points or fewer); this also makes single points visible
        if len(points) <= 30:
            for px, py in points:
                lines.append(
                    f'<circle cx="{px:.1f}" cy="{py:.1f}" r="3" fill="{color}"/>'
                )

    # Title
    if title:
        lines.append(
            f'<text x="{width / 2}" y="22" text-anchor="middle" '
            f'font-size="14" font-weight="600" fill="#1a1a2e">{title}</text>'
        )

    # Axis labels
    if x_label:
        lines.append(
            f'<text x="{ml + cw / 2}" y="{height - 6}" text-anchor="middle" '
            f'font-size="12" fill="#6c757d">{x_label}</text>'
        )
    if y_label:
        lines.append(
            f'<text x="14" y="{mt + ch / 2}" text-anchor="middle" '
            f'font-size="12" fill="#6c757d" '
            f'transform="rotate(-90, 14, {mt + ch / 2})">{y_label}</text>'
        )

    # Legend
    names = y_display_names or {}
    if len(y_keys) > 1:
        lx = ml + 10
        for si, key in enumerate(y_keys):
            color = colors[si % len(colors)]
            ly = mt + 14 + si * 18
            lines.append(
                f'<rect x="{lx}" y="{ly - 6}" width="12" height="12" '
                f'rx="2" fill="{color}"/>'
            )
            label = names.get(key, key)
            lines.append(
                f'<text x="{lx + 16}" y="{ly + 4}" font-size="11" '
                f'fill="#1a1a2e">{label}</text>'
            )

    lines.append("</svg>")
    return "\n".join(lines)

def _make_bar_chart(
    labels: list[str],
    series: dict[str, list[float]],
    title: str = "",
    colors: list[str] | None = None,
    width: int = 700,
    height: int = 300,
    y_max_cap: float | None = None,
) -> str:
    """Generate an SVG grouped bar chart.

    Args:
        labels: Category labels for x-axis.
        series: Dict mapping series name to list of values (same length as labels).
        title: Chart title.
        colors: Colors for each series.
        width: SVG width.
        height: SVG height.
        y_max_cap: If set, cap the y-axis at this value (e.g. 1.0 for mAP).

    Returns:
        SVG string.
    """
    if not labels:
        return ""

    default_colors = ["#adb5bd", "#0f3460", "#06d6a0", "#5a7db5"]
    colors = colors or default_colors

    ml, mr, mt, mb = 60, 20, 40, 60
    cw = width - ml - mr
    ch = height - mt - mb

    all_vals = [v for vals in series.values() for v in vals]
    y_max = max(all_vals) if all_vals else 1
    y_max_plot = y_max * 1.15 or 1
    if y_max_cap is not None:
        y_max_plot = min(y_max_plot, y_max_cap) or y_max_cap
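    # Example: if the tallest bar is 0.9, then 0.9 * 1.15 ≈ 1.035. That leaves 15%
    # headroom for the value labels drawn above each bar, and y_max_cap=1.0
    # clamps it to 1.0. The `or 1` keeps the axis non-zero when every value is 0.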

    n_groups = len(labels)
    n_series = len(series)
    group_width = cw / n_groups
    bar_width = group_width * 0.7 / max(n_series, 1)
    gap = group_width * 0.15

    def sy(v):
        return mt + ch - (v / y_max_plot) * ch

    lines_svg = [
        f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {height}" '
        f'style="width:100%;max-width:{width}px;height:auto;">',
        f'<rect width="{width}" height="{height}" fill="#fff" rx="6"/>',
    ]

    # Grid lines
    for i in range(6):
        y_tick = y_max_plot * i / 5
        py = sy(y_tick)
        lines_svg.append(
            f'<line x1="{ml}" y1="{py:.1f}" x2="{ml + cw}" y2="{py:.1f}" '
            f'stroke="#e9ecef" stroke-width="1"/>'
        )
        lines_svg.append(
            f'<text x="{ml - 8}" y="{py + 4:.1f}" text-anchor="end" '
            f'font-size="11" fill="#6c757d">{y_tick:.3f}</text>'
        )

    # Bars
    for gi, label in enumerate(labels):
        gx = ml + gi * group_width + gap
        for si, (name, vals) in enumerate(series.items()):
            color = colors[si % len(colors)]
            bx = gx + si * bar_width
            val = vals[gi]
            by = sy(val)
            bh = mt + ch - by
            lines_svg.append(
                f'<rect x="{bx:.1f}" y="{by:.1f}" width="{bar_width - 1:.1f}" '
                f'height="{bh:.1f}" fill="{color}" rx="2"/>'
            )
            # Value label on top of bar
            lines_svg.append(
                f'<text x="{bx + bar_width / 2:.1f}" y="{by - 4:.1f}" '
                f'text-anchor="middle" font-size="10" fill="#1a1a2e">'
                f'{val:.3f}</text>'
            )
        # Group label
        lines_svg.append(
            f'<text x="{gx + n_series * bar_width / 2:.1f}" y="{mt + ch + 18}" '
            f'text-anchor="middle" font-size="11" fill="#6c757d">{label}</text>'
        )

    # Title
    if title:
        lines_svg.append(
            f'<text x="{width / 2}" y="22" text-anchor="middle" '
            f'font-size="14" font-weight="600" fill="#1a1a2e">{title}</text>'
        )

    # Legend
    lx = ml + cw - len(series) * 100
    for si, name in enumerate(series):
        color = colors[si % len(colors)]
        lines_svg.append(
            f'<rect x="{lx + si * 100}" y="{mt + ch + 35}" width="12" '
            f'height="12" rx="2" fill="{color}"/>'
        )
        lines_svg.append(
            f'<text x="{lx + si * 100 + 16}" y="{mt + ch + 46}" font-size="11" '
            f'fill="#1a1a2e">{name}</text>'
        )

    lines_svg.append("</svg>")
    return "\n".join(lines_svg)

# ------------------------------------------------------------------
# Task 1: Prepare dataset — download COCO JSON + images, split train/val
# ------------------------------------------------------------------

@cpu_env.task(cache="auto")
async def prepare_data(
    dataset_repo: str = "sagecodes/union_flyte_swag_object_detection",
    annotations_path: str = "swag/train.json",
    images_subdir: str = "swag/images",
    val_fraction: float = 0.2,
    seed: int = 42,
) -> flyte.io.Dir:
    """Download a COCO-format dataset from HF and split into train/val."""
    from huggingface_hub import snapshot_download

    log.info(f"Downloading dataset: {dataset_repo}")
    local_repo = snapshot_download(
        repo_id=dataset_repo,
        repo_type="dataset",
    )

    ann_file = os.path.join(local_repo, annotations_path)
    img_root = os.path.join(local_repo, images_subdir)

    with open(ann_file) as f:
        coco = json.load(f)

    images = coco["images"]
    annotations = coco["annotations"]
    categories = coco["categories"]

    log.info(
        f"Loaded {len(images)} images, {len(annotations)} annotations, "
        f"{len(categories)} categories"
    )
    log.info(f"Raw category ids: {sorted({c['id'] for c in categories})}")
    log.info(
        f"Raw annotation category_ids (unique): "
        f"{sorted({a['category_id'] for a in annotations})}"
    )

    # Remap category ids to contiguous 0..N-1. This is required because HF object
    # detection models size their classifier head to len(id2label) and treat
    # class labels as direct indices into that head. Any gap or 1-indexed id
    # causes an out-of-bounds index error (IndexKernel) in the focal-loss scatter.
    #
    # Build the remap from the UNION of ids declared in `categories` and ids
    # actually used in `annotations` — some datasets have orphaned annotations
    # referencing categories that aren't declared (this one does).
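    #
    # Worked example (illustrative ids, not taken from this dataset):
    #   declared ids {1, 2, 5}, ids used in annotations {1, 5, 7}
    #   -> orphans {7}, union sorted [1, 2, 5, 7]
    #   -> id_remap {1: 0, 2: 1, 5: 2, 7: 3}, and id 7 gets the stub name "category_7".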
    declared_ids = {c["id"] for c in categories}
    used_ids = {a["category_id"] for a in annotations}
    orphans = used_ids - declared_ids
    if orphans:
        log.warning(
            f"Annotations reference undeclared category ids {sorted(orphans)} — "
            f"adding stub categories."
        )

    all_cat_ids = sorted(declared_ids | used_ids)
    id_remap = {old: new for new, old in enumerate(all_cat_ids)}
    existing_names = {c["id"]: c["name"] for c in categories}
    categories = [
        {"id": id_remap[old], "name": existing_names.get(old, f"category_{old}")}
        for old in all_cat_ids
    ]
    annotations = [
        {**a, "category_id": id_remap[a["category_id"]]} for a in annotations
    ]
    log.info(f"Remapped category ids: {id_remap}")
    log.info(f"Final categories: {categories}")

    # Split by image id
    rng = random.Random(seed)
    img_ids = [im["id"] for im in images]
    rng.shuffle(img_ids)
    n_val = max(1, int(len(img_ids) * val_fraction))
    val_ids = set(img_ids[:n_val])
    train_ids = set(img_ids[n_val:])
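    # e.g. 100 images with val_fraction=0.2 -> n_val = 20 (80 train / 20 val).
    # With 3 images, int(0.6) = 0, so max(1, ...) still reserves 1 validation image.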

    def filter_coco(keep_ids: set) -> dict:
        return {
            "info": coco.get("info", {}),
            "categories": categories,
            "images": [im for im in images if im["id"] in keep_ids],
            "annotations": [a for a in annotations if a["image_id"] in keep_ids],
        }

    train_coco = filter_coco(train_ids)
    val_coco = filter_coco(val_ids)

    log.info(
        f"Split: {len(train_coco['images'])} train / {len(val_coco['images'])} val images"
    )

    # Pack output dir: images/ + train.json + val.json
    out_dir = tempfile.mkdtemp(prefix="coco_split_")
    out_img = os.path.join(out_dir, "images")
    shutil.copytree(img_root, out_img)

    with open(os.path.join(out_dir, "train.json"), "w") as f:
        json.dump(train_coco, f)
    with open(os.path.join(out_dir, "val.json"), "w") as f:
        json.dump(val_coco, f)

    return await flyte.io.Dir.from_local(out_dir)

# ------------------------------------------------------------------
# Helpers — torch Dataset wrapping COCO JSON
# ------------------------------------------------------------------

def _build_torch_dataset(coco_path: str, images_root: str, augment: bool):
    """Build a torch Dataset that yields {image, target} for the HF image processor."""
    import albumentations as A
    import numpy as np
    from PIL import Image
    from torch.utils.data import Dataset

    with open(coco_path) as f:
        coco = json.load(f)

    images_by_id = {im["id"]: im for im in coco["images"]}
    anns_by_image: dict[int, list] = {}
    for a in coco["annotations"]:
        anns_by_image.setdefault(a["image_id"], []).append(a)

    image_ids = list(images_by_id.keys())

    # NOTE: we deliberately don't resize here — the HF image processor handles
    # resize+pad. Augmentation only.
    if augment:
        transform = A.Compose(
            [
                A.HorizontalFlip(p=0.5),
                A.VerticalFlip(p=0.1),
                A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.5),
                A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=30, val_shift_limit=20, p=0.4),
                A.Rotate(limit=15, border_mode=0, p=0.4),
                A.RandomScale(scale_limit=0.2, p=0.4),
                A.GaussianBlur(blur_limit=(3, 5), p=0.2),
                A.GaussNoise(p=0.2),
            ],
            bbox_params=A.BboxParams(
                format="coco",
                label_fields=["category"],
                min_area=4,
                min_visibility=0.1,
                clip=True,
            ),
        )
    else:
        transform = A.Compose(
            [A.NoOp()],
            bbox_params=A.BboxParams(format="coco", label_fields=["category"], clip=True),
        )

    class CocoDataset(Dataset):
        def __len__(self) -> int:
            return len(image_ids)

        def __getitem__(self, idx: int):
            img_id = image_ids[idx]
            meta = images_by_id[img_id]
            img_path = os.path.join(images_root, os.path.basename(meta["file_name"]))
            if not os.path.exists(img_path):
                img_path = os.path.join(images_root, meta["file_name"])
            image = np.array(Image.open(img_path).convert("RGB"))

            anns = anns_by_image.get(img_id, [])
            bboxes = [a["bbox"] for a in anns]
            categories = [a["category_id"] for a in anns]

            out = transform(image=image, bboxes=bboxes, category=categories)
            image_t = out["image"]
            bboxes_t = out["bboxes"]
            categories_t = out["category"]

            target_anns = []
            for bb, cat in zip(bboxes_t, categories_t):
                x, y, w, h = bb
                target_anns.append(
                    {
                        "image_id": img_id,
                        "category_id": int(cat),
                        "bbox": [float(x), float(y), float(w), float(h)],
                        "area": float(w * h),
                        "iscrowd": 0,
                    }
                )

            return {
                "image": image_t,
                "target": {"image_id": img_id, "annotations": target_anns},
            }

    return CocoDataset(), coco["categories"]

# ------------------------------------------------------------------
# Task 2: Train
# ------------------------------------------------------------------

@gpu_env.task(report=True)
async def train(
    model_name: str,
    data_dir: flyte.io.Dir,
    epochs: int = 30,
    lr: float = 5e-5,
    batch_size: int = 4,
    weight_decay: float = 1e-4,
    eval_every_n_epochs: int | None = None,
) -> flyte.io.Dir:
    """Fine-tune RT-DETR (or any HuggingFace object-detection model) on COCO data."""
    import torch
    from transformers import (
        AutoImageProcessor,
        AutoModelForObjectDetection,
        Trainer,
        TrainerCallback,
        TrainingArguments,
    )

    log.info(f"Training: model={model_name}")
    await flyte.report.replace.aio(_wrap_report(
        f"<h2>Loading model...</h2><p>{model_name}</p>"
        f"<p>Preparing dataset and initializing weights...</p>"
    ), do_flush=True)

    # -- Load data --
    data_path = await data_dir.download()
    images_root = os.path.join(data_path, "images")
    train_json = os.path.join(data_path, "train.json")

    with open(train_json) as f:
        categories = json.load(f)["categories"]
    id2label = {c["id"]: c["name"] for c in categories}
    label2id = {v: k for k, v in id2label.items()}

    train_ds, _ = _build_torch_dataset(train_json, images_root, augment=True)
    log.info(f"Train examples: {len(train_ds)} | Categories: {id2label}")

    # -- Optionally load val set for periodic mAP evaluation --
    val_json = os.path.join(data_path, "val.json")
    val_images = None
    val_targets = None
    if eval_every_n_epochs and os.path.exists(val_json):
        import torch as _torch
        from PIL import Image

        with open(val_json) as f:
            val_coco = json.load(f)
        images_by_id = {im["id"]: im for im in val_coco["images"]}
        anns_by_image: dict[int, list] = {}
        for a in val_coco["annotations"]:
            anns_by_image.setdefault(a["image_id"], []).append(a)
        val_images = []
        val_targets = []
        for img_id, meta in images_by_id.items():
            path = os.path.join(images_root, os.path.basename(meta["file_name"]))
            if not os.path.exists(path):
                path = os.path.join(images_root, meta["file_name"])
            val_images.append(Image.open(path).convert("RGB"))
            boxes_xyxy = []
            labels = []
            for a in anns_by_image.get(img_id, []):
                x, y, w, h = a["bbox"]
                boxes_xyxy.append([x, y, x + w, y + h])
                labels.append(a["category_id"])
            val_targets.append({
                "boxes": _torch.tensor(boxes_xyxy, dtype=_torch.float32).reshape(-1, 4),
                "labels": _torch.tensor(labels, dtype=_torch.long),
            })
        log.info(f"Val examples for periodic eval: {len(val_images)}")

    # -- Processor + model --
    processor = AutoImageProcessor.from_pretrained(model_name)
    model = AutoModelForObjectDetection.from_pretrained(
        model_name,
        num_labels=len(id2label),
        id2label=id2label,
        label2id=label2id,
        ignore_mismatched_sizes=True,
    )

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    log.info(
        f"Parameters: {trainable_params:,} / {total_params:,} "
        f"({trainable_params / total_params * 100:.1f}%)"
    )

    # -- Collator — runs the image processor on each batch --
    def collate_fn(batch):
        images = [b["image"] for b in batch]
        targets = [b["target"] for b in batch]
        enc = processor(images=images, annotations=targets, return_tensors="pt")
        return {"pixel_values": enc["pixel_values"], "labels": enc["labels"]}

    # -- Sanity check: peek at one batch and verify class_labels fit --
    sample = collate_fn([train_ds[i] for i in range(min(2, len(train_ds)))])
    all_labels = []
    for lbl in sample["labels"]:
        all_labels.extend(lbl["class_labels"].tolist())
    log.info(
        f"Sanity check — class_labels in first batch: {sorted(set(all_labels))} | "
        f"model num_labels: {model.config.num_labels} | "
        f"id2label: {model.config.id2label}"
    )
    if all_labels and max(all_labels) >= model.config.num_labels:
        raise ValueError(
            f"class_label {max(all_labels)} out of range for num_labels="
            f"{model.config.num_labels}. Check category id remapping in prepare_data."
        )
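    # e.g. if category ids were left 1-indexed as {1, 2, 3} with num_labels=3,
    # label 3 would index past the last class logit (valid indices are 0..2).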

    # -- Collect training metrics and update the report chart live.
    # trainer.train() runs in a background thread (via asyncio.to_thread),
    # so the asyncio event loop stays free. We use run_coroutine_threadsafe
    # to push report updates from the callback thread onto that loop.
    training_log: list[dict] = []
    eval_log: list[dict] = []  # periodic mAP checkpoints (epoch, map, map_50)
    loop = asyncio.get_running_loop()
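    # Sketch of the thread-to-loop hand-off described above:
    #   fut = asyncio.run_coroutine_threadsafe(coro, loop)
    # This schedules `coro` on `loop` from the worker thread and returns a
    # concurrent.futures.Future. Calling fut.result() would block the worker
    # thread until the coroutine finishes; ignoring it makes the update fire-and-forget.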

    cat_badges = " ".join(
        f'<span class="badge badge-info">{name}</span>'
        for name in id2label.values()
    )

    def _build_training_report(max_steps: int) -> str:
        """Build the live training report HTML from current training_log."""
        stats_html = f"""
        <h2>Training in Progress...</h2>
        <h3>{model_name}</h3>
        <div class="stat-grid">
          <div class="stat"><div class="value">{len(train_ds)}</div><div class="label">Train Examples</div></div>
          <div class="stat"><div class="value">{epochs}</div><div class="label">Epochs</div></div>
          <div class="stat"><div class="value">{lr}</div><div class="label">Learning Rate</div></div>
          <div class="stat"><div class="value">{batch_size}</div><div class="label">Batch Size</div></div>
          <div class="stat"><div class="value">{total_params:,}</div><div class="label">Total Params</div></div>
          <div class="stat"><div class="value">{trainable_params / total_params * 100:.1f}%</div><div class="label">Trainable</div></div>
        </div>
        <p>Categories: {cat_badges}</p>
        """

        charts_html = ""
        if training_log:
            current = training_log[-1]
            progress_pct = current["step"] / max_steps * 100 if max_steps else 0
            charts_html += f"""
            <div class="card">
              <b>Step {current['step']}/{max_steps}</b>
              ({progress_pct:.0f}%) |
              Epoch {current['epoch']:.2f}/{epochs} |
              Loss: <span class="highlight">{current['loss']:.4f}</span>
              <div style="background:#e9ecef;border-radius:4px;height:8px;margin-top:8px;">
                <div style="background:#0f3460;width:{progress_pct:.1f}%;height:100%;border-radius:4px;"></div>
              </div>
            </div>
            """

            loss_chart = _make_line_chart(
                data=training_log,
                x_key="epoch",
                y_keys=["loss"],
                title="Training Loss",
                x_label="Epoch",
                y_label="Loss",
                colors=["#5a7db5"],
            )
            charts_html += f'<div class="chart-container">{loss_chart}</div>'

            if eval_every_n_epochs and eval_log:
                # Match x-axis to the loss chart: start at 0, end at current epoch
                current_max_epoch = current["epoch"]
                map_chart = _make_line_chart(
                    data=eval_log,
                    x_key="epoch",
                    y_keys=["map", "map_50"],
                    title="Validation mAP (periodic)",
                    x_label="Epoch",
                    y_label="mAP",
                    colors=["#0f3460", "#06d6a0"],
                    y_max_cap=1.0,
                    y_display_names={"map": "mAP (0.50:0.95)", "map_50": "mAP@50"},
                    x_range_override=(0, current_max_epoch),
                )
                charts_html += f'<div class="chart-container">{map_chart}</div>'

            if "lr" in training_log[0]:
                lr_chart = _make_line_chart(
                    data=training_log,
                    x_key="epoch",
                    y_keys=["lr"],
                    title="Learning Rate Schedule",
                    x_label="Epoch",
                    y_label="LR",
                    colors=["#0f3460"],
                )
                charts_html += f'<div class="chart-container">{lr_chart}</div>'

        return _wrap_report(stats_html + charts_html)

    class MetricsCallback(TrainerCallback):
        def __init__(self):
            self._last_eval_epoch = 0

        def on_log(self, args, state, control, logs=None, **kwargs):
            if not logs or "loss" not in logs:
                return
            entry = {
                "step": state.global_step,
                "epoch": round(logs.get("epoch", 0), 2),
                "loss": round(logs["loss"], 4),
            }
            if "learning_rate" in logs:
                entry["lr"] = logs["learning_rate"]
            if "grad_norm" in logs:
                entry["grad_norm"] = round(float(logs["grad_norm"]), 4)
            training_log.append(entry)
            log.info(
                f"step={state.global_step}/{state.max_steps} "
                f"epoch={entry['epoch']:.2f} "
                f"loss={entry['loss']:.4f}"
            )

            # Push a live report update onto the asyncio event loop.
            # do_flush=True dispatches the update to the UI immediately.
            asyncio.run_coroutine_threadsafe(
                flyte.report.replace.aio(
                    _build_training_report(state.max_steps),
                    do_flush=True,
                ),
                loop,
            )

        def on_epoch_end(self, args, state, control, model=None, **kwargs):
            if not eval_every_n_epochs or val_images is None:
                return
            current_epoch = round(state.epoch)
            if current_epoch % eval_every_n_epochs != 0:
                return
            if current_epoch == self._last_eval_epoch:
                return
            self._last_eval_epoch = current_epoch

            log.info(f"Running periodic mAP eval at epoch {current_epoch}...")
            from torchmetrics.detection.mean_ap import MeanAveragePrecision

            device = next(model.parameters()).device
            # _run_inference sets model.eval(); restore train mode after.
            preds = _run_inference(model, processor, val_images, device, threshold=0.3)
            model.train()

            formatted = [
                {"boxes": p["boxes"], "scores": p["scores"], "labels": p["labels"]}
                for p in preds
            ]
            metric = MeanAveragePrecision(box_format="xyxy", iou_type="bbox")
            metric.update(formatted, val_targets)
            result = metric.compute()
            map_val = round(result["map"].item(), 4)
            map_50 = round(result["map_50"].item(), 4)

            eval_log.append({
                "epoch": current_epoch,
                "map": map_val,
                "map_50": map_50,
            })
            log.info(f"Epoch {current_epoch} — mAP: {map_val:.4f}, mAP@50: {map_50:.4f}")

            asyncio.run_coroutine_threadsafe(
                flyte.report.replace.aio(
                    _build_training_report(state.max_steps),
                    do_flush=True,
                ),
                loop,
            )

    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    output_dir = os.path.join(tempfile.mkdtemp(), "checkpoints")
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        learning_rate=lr,
        weight_decay=weight_decay,
        logging_steps=5,
        save_strategy="no",
        bf16=use_bf16,
        fp16=not use_bf16 and torch.cuda.is_available(),
        warmup_ratio=0.1,
        remove_unused_columns=False,
        dataloader_num_workers=2,
        report_to="none",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        data_collator=collate_fn,
        callbacks=[MetricsCallback()],
    )

    log.info("Starting training...")
    # Run the sync HF training loop in a thread so the asyncio event loop
    # stays free for Flyte's syncify bridge.
    await asyncio.to_thread(trainer.train)
    log.info("Training complete.")

    save_dir = os.path.join(tempfile.mkdtemp(), "finetuned_model")
    trainer.save_model(save_dir)
    processor.save_pretrained(save_dir)
    log.info(f"Model saved to {save_dir}")

    # -- Build final training report --
    stats_html = f"""
    <h2>Training Complete</h2>
    <h3>{model_name}</h3>
    <div class="stat-grid">
      <div class="stat"><div class="value">{len(train_ds)}</div><div class="label">Train Examples</div></div>
      <div class="stat"><div class="value">{epochs}</div><div class="label">Epochs</div></div>
      <div class="stat"><div class="value">{lr}</div><div class="label">Learning Rate</div></div>
      <div class="stat"><div class="value">{batch_size}</div><div class="label">Batch Size</div></div>
      <div class="stat"><div class="value">{total_params:,}</div><div class="label">Total Params</div></div>
      <div class="stat"><div class="value">{trainable_params / total_params * 100:.1f}%</div><div class="label">Trainable</div></div>
    </div>
    <p>Categories: {cat_badges}</p>
    """

    charts_html = ""
    if training_log:
        final_loss = training_log[-1]["loss"]
        min_loss = min(d["loss"] for d in training_log)
        initial_loss = training_log[0]["loss"]
        total_steps = training_log[-1]["step"]

        charts_html += f"""
        <div class="card">
          <b>Training Summary:</b>
          Initial loss: {initial_loss:.4f} |
          Final loss: <span class="highlight">{final_loss:.4f}</span> |
          Min loss: {min_loss:.4f} |
          Total steps: {total_steps}
        </div>
        """

        epoch_range = (0, epochs)

        loss_chart = _make_line_chart(
            data=training_log,
            x_key="epoch",
            y_keys=["loss"],
            title="Training Loss",
            x_label="Epoch",
            y_label="Loss",
            colors=["#5a7db5"],
            x_range_override=epoch_range,
        )
        charts_html += f'<div class="chart-container">{loss_chart}</div>'

    if eval_log:
        map_chart = _make_line_chart(
            data=eval_log,
            x_key="epoch",
            y_keys=["map", "map_50"],
            title="Validation mAP (periodic)",
            x_label="Epoch",
            y_label="mAP",
            colors=["#0f3460", "#06d6a0"],
            y_max_cap=1.0,
            x_range_override=(0, epochs),
            y_display_names={"map": "mAP (0.50:0.95)", "map_50": "mAP@50"},
        )
        charts_html += f'<div class="chart-container">{map_chart}</div>'

    if training_log and "lr" in training_log[0]:
        lr_chart = _make_line_chart(
            data=training_log,
            x_key="epoch",
            y_keys=["lr"],
            title="Learning Rate Schedule",
            x_label="Epoch",
            y_label="LR",
            colors=["#0f3460"],
            x_range_override=(0, epochs),
        )
        charts_html += f'<div class="chart-container">{lr_chart}</div>'

    await flyte.report.replace.aio(_wrap_report(stats_html + charts_html), do_flush=True)

    return await flyte.io.Dir.from_local(save_dir)

# ------------------------------------------------------------------
# Inference helpers
# ------------------------------------------------------------------

def _run_inference(model, processor, images, device, threshold: float = 0.3):
    """Run object detection on a list of PIL images. Returns list of dicts."""
    import torch

    results = []
    model.eval()
    for img in images:
        inputs = processor(images=img, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        target_size = torch.tensor([img.size[::-1]], device=device)  # (h, w)
        post = processor.post_process_object_detection(
            outputs, target_sizes=target_size, threshold=threshold
        )[0]
        results.append(
            {
                "scores": post["scores"].cpu(),
                "labels": post["labels"].cpu(),
                "boxes": post["boxes"].cpu(),  # xyxy in original image coords
            }
        )
    return results

def _draw_boxes(image, boxes, labels, scores, id2label, color: str = "lime"):
    """Draw bounding boxes on a PIL image. Returns a new PIL image."""
    from PIL import ImageDraw, ImageFont

    img = image.copy()
    draw = ImageDraw.Draw(img)
    try:
        font = ImageFont.truetype("DejaVuSans-Bold.ttf", size=max(14, img.width // 60))
    except Exception:
        font = ImageFont.load_default()

    for box, label, score in zip(boxes.tolist(), labels.tolist(), scores.tolist()):
        x0, y0, x1, y1 = box
        width = max(2, img.width // 400)
        draw.rectangle([x0, y0, x1, y1], outline=color, width=width)
        name = id2label.get(int(label), str(int(label)))
        caption = f"{name} {score:.2f}"
        text_bg = draw.textbbox((x0, y0), caption, font=font)
        draw.rectangle(text_bg, fill=color)
        draw.text((x0, y0), caption, fill="black", font=font)
    return img

def _img_to_data_uri(img, max_dim: int = 800) -> str:
    """PIL image → base64 data URI, downscaled for the report."""
    w, h = img.size
    if max(w, h) > max_dim:
        scale = max_dim / max(w, h)
        img = img.resize((int(w * scale), int(h * scale)))
    buf = io.BytesIO()
    img.save(buf, format="JPEG", quality=85)
    return "data:image/jpeg;base64," + base64.b64encode(buf.getvalue()).decode()

# ------------------------------------------------------------------
# Task 3: Evaluate — COCO mAP on fine-tuned model
# ------------------------------------------------------------------

@gpu_env.task(report=True)
async def evaluate(
    finetuned_dir: flyte.io.Dir,
    data_dir: flyte.io.Dir,
    threshold: float = 0.5,
) -> str:
    """Compute COCO mAP for the fine-tuned model on the val split."""
    import torch
    from PIL import Image
    from torchmetrics.detection.mean_ap import MeanAveragePrecision
    from transformers import AutoImageProcessor, AutoModelForObjectDetection

    log.info("Starting evaluation...")
    await flyte.report.replace.aio(_wrap_report(
        "<h2>Evaluation</h2><p>Loading val split and scoring model...</p>"
    ), do_flush=True)

    data_path = await data_dir.download()
    images_root = os.path.join(data_path, "images")
    val_json = os.path.join(data_path, "val.json")

    with open(val_json) as f:
        val_coco = json.load(f)

    images_by_id = {im["id"]: im for im in val_coco["images"]}
    anns_by_image: dict[int, list] = {}
    for a in val_coco["annotations"]:
        anns_by_image.setdefault(a["image_id"], []).append(a)

    pil_images = []
    targets = []
    for img_id, meta in images_by_id.items():
        path = os.path.join(images_root, os.path.basename(meta["file_name"]))
        if not os.path.exists(path):
            path = os.path.join(images_root, meta["file_name"])
        pil_images.append(Image.open(path).convert("RGB"))
        boxes_xyxy = []
        labels = []
        for a in anns_by_image.get(img_id, []):
            x, y, w, h = a["bbox"]
            boxes_xyxy.append([x, y, x + w, y + h])
            labels.append(a["category_id"])
        targets.append(
            {
                "boxes": torch.tensor(boxes_xyxy, dtype=torch.float32).reshape(-1, 4),
                "labels": torch.tensor(labels, dtype=torch.long),
            }
        )

    device = "cuda" if torch.cuda.is_available() else "cpu"

    ft_path = await finetuned_dir.download()
    log.info(f"Scoring fine-tuned model: {ft_path}")
    processor = AutoImageProcessor.from_pretrained(ft_path)
    model = AutoModelForObjectDetection.from_pretrained(ft_path).to(device)
    preds = _run_inference(model, processor, pil_images, device, threshold=threshold)

    formatted_preds = [
        {"boxes": p["boxes"], "scores": p["scores"], "labels": p["labels"]}
        for p in preds
    ]

    metric = MeanAveragePrecision(box_format="xyxy", iou_type="bbox")
    metric.update(formatted_preds, targets)

    def to_python(v):
        if hasattr(v, "numel"):
            return v.item() if v.numel() == 1 else v.tolist()
        return v

    ft_metrics = {k: to_python(v) for k, v in metric.compute().items()}
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    log.info(f"Fine-tuned mAP: {ft_metrics.get('map', 0):.3f}")

    metric_keys = ["map", "map_50", "map_75", "mar_10"]
    metric_display = {
        "map": "mAP",
        "map_50": "mAP@50",
        "map_75": "mAP@75",
        "mar_10": "mAR@10",
    }

    rows = []
    for key in metric_keys:
        ft_val = ft_metrics.get(key, 0)
        rows.append(
            f"<tr><td><b>{metric_display.get(key, key)}</b></td>"
            f"<td class='highlight'>{ft_val:.3f}</td></tr>"
        )
    table = (
        "<table><tr><th>Metric</th><th>Score</th></tr>"
        + "".join(rows)
        + "</table>"
    )

    bar_chart = _make_bar_chart(
        labels=[metric_display.get(k, k) for k in metric_keys],
        series={"Fine-tuned": [ft_metrics.get(k, 0) for k in metric_keys]},
        title="COCO Evaluation Metrics",
        colors=["#0f3460"],
        y_max_cap=1.0,
    )

    ft_map = ft_metrics.get("map", 0)
    ft_map50 = ft_metrics.get("map_50", 0)

    eval_html = f"""
    <h2>Evaluation — COCO mAP</h2>
    <div class="stat-grid">
      <div class="stat"><div class="value">{len(pil_images)}</div><div class="label">Val Images</div></div>
      <div class="stat"><div class="value">{threshold}</div><div class="label">Threshold</div></div>
      <div class="stat"><div class="value highlight">{ft_map:.3f}</div><div class="label">mAP</div></div>
      <div class="stat"><div class="value highlight">{ft_map50:.3f}</div><div class="label">mAP@50</div></div>
    </div>
    <div class="chart-container">{bar_chart}</div>
    {table}
    <div class="note">
      <b>mAP</b> (mean Average Precision) summarizes how accurately the model
      detects objects, balancing whether predictions are correct (precision)
      and whether all objects are found (recall). The headline mAP averages over
      IoU thresholds from 0.50 to 0.95; the @50 and @75 variants use a single
      IoU threshold of 0.50 or 0.75 between predicted and ground-truth boxes.
      <b>mAR@10</b> (mean Average Recall) measures how many ground-truth objects
      the model finds when limited to at most 10 detections per image.
    </div>
    """

    await flyte.report.replace.aio(_wrap_report(eval_html), do_flush=True)

    return json.dumps(
        {
            "finetuned": {k: round(v, 4) for k, v in ft_metrics.items() if isinstance(v, (int, float))},
            "num_val_images": len(pil_images),
        }
    )

# ------------------------------------------------------------------
# Task 4: Inference demo — render bboxes on val images
# ------------------------------------------------------------------

@gpu_env.task(report=True)
async def inference_demo(
    finetuned_dir: flyte.io.Dir,
    data_dir: flyte.io.Dir,
    threshold: float = 0.5,
    max_images: int = 8,
    metrics_json: str = "{}",
) -> str:
    """Run the fine-tuned model on val images, render bboxes, embed in the report."""
    import torch
    from PIL import Image
    from torchmetrics.detection.mean_ap import MeanAveragePrecision
    from transformers import AutoImageProcessor, AutoModelForObjectDetection

    data_path = await data_dir.download()
    images_root = os.path.join(data_path, "images")
    val_json = os.path.join(data_path, "val.json")

    with open(val_json) as f:
        val_coco = json.load(f)

    id2label = {c["id"]: c["name"] for c in val_coco["categories"]}
    metas = val_coco["images"][:max_images]
    anns_by_image: dict[int, list] = {}
    for a in val_coco["annotations"]:
        anns_by_image.setdefault(a["image_id"], []).append(a)

    pil_images = []
    gt_per_image = []
    for meta in metas:
        path = os.path.join(images_root, os.path.basename(meta["file_name"]))
        if not os.path.exists(path):
            path = os.path.join(images_root, meta["file_name"])
        pil_images.append(Image.open(path).convert("RGB"))

        boxes_xyxy = []
        labels = []
        for a in anns_by_image.get(meta["id"], []):
            x, y, w, h = a["bbox"]
            boxes_xyxy.append([x, y, x + w, y + h])
            labels.append(a["category_id"])
        gt_per_image.append(
            {
                "boxes": torch.tensor(boxes_xyxy, dtype=torch.float32).reshape(-1, 4),
                "labels": torch.tensor(labels, dtype=torch.long),
                "scores": torch.ones(len(labels)),
            }
        )

    ft_path = await finetuned_dir.download()
    processor = AutoImageProcessor.from_pretrained(ft_path)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModelForObjectDetection.from_pretrained(ft_path).to(device)

    preds = _run_inference(model, processor, pil_images, device, threshold=threshold)

    html_blocks = []
    total_gt = 0
    total_pred = 0
    for i, (img, pred, gt) in enumerate(zip(pil_images, preds, gt_per_image)):
        n_gt = len(gt["labels"])
        n_pred = len(pred["labels"])
        total_gt += n_gt
        total_pred += n_pred

        # Per-image mAP
        metric = MeanAveragePrecision(box_format="xyxy", iou_type="bbox")
        metric.update(
            [{"boxes": pred["boxes"], "scores": pred["scores"], "labels": pred["labels"]}],
            [{"boxes": gt["boxes"], "labels": gt["labels"]}],
        )
        img_metrics = metric.compute()
        img_map = img_metrics["map"].item()
        img_map_badge = (
            f'<span class="badge badge-success">mAP {img_map:.2f}</span>'
            if img_map >= 0.5
            else f'<span class="badge badge-info">mAP {img_map:.2f}</span>'
        )

        pred_img = _draw_boxes(
            img, pred["boxes"], pred["labels"], pred["scores"],
            id2label, color="lime",
        )
        gt_img = _draw_boxes(
            img, gt["boxes"], gt["labels"], gt["scores"],
            id2label, color="dodgerblue",
        )
        html_blocks.append(f"""
        <div class="card">
          <b>Image {i + 1}</b> {img_map_badge}
          <div class="img-pair">
            <div>
              <p><span class="gt-label">Ground Truth</span>
                 <span class="badge badge-info">{n_gt} boxes</span></p>
              <img src="{_img_to_data_uri(gt_img)}" />
            </div>
            <div>
              <p><span class="pred-label">Predictions</span>
                 <span class="badge badge-success">{n_pred} boxes</span>
                 (threshold={threshold})</p>
              <img src="{_img_to_data_uri(pred_img)}" />
            </div>
          </div>
        </div>""")

    # Parse metrics if provided (from evaluate task)
    metrics = json.loads(metrics_json)
    ft_metrics = metrics.get("finetuned", {})
    ft_map = ft_metrics.get("map", None)
    ft_map50 = ft_metrics.get("map_50", None)

    metrics_stats = ""
    if ft_map is not None:
        metrics_stats = f"""
        <div class="stat"><div class="value highlight">{ft_map:.3f}</div><div class="label">mAP</div></div>
        <div class="stat"><div class="value highlight">{ft_map50:.3f}</div><div class="label">mAP@50</div></div>
        """

    demo_html = f"""
    <h2>Inference Demo</h2>
    <h3>Fine-tuned RT-DETR on validation images</h3>
    <div class="stat-grid">
      {metrics_stats}
      <div class="stat"><div class="value">{len(pil_images)}</div><div class="label">Images Shown</div></div>
      <div class="stat"><div class="value">{total_gt}</div><div class="label">Ground Truth Boxes</div></div>
      <div class="stat"><div class="value">{total_pred}</div><div class="label">Predicted Boxes</div></div>
      <div class="stat"><div class="value">{threshold}</div><div class="label">Confidence Threshold</div></div>
    </div>
    <p><span class="gt-label">Blue = ground truth</span> |
       <span class="pred-label">Green = predictions</span></p>
    {"".join(html_blocks)}
    """

    await flyte.report.replace.aio(_wrap_report(demo_html), do_flush=True)

    return json.dumps(
        {
            "num_images": len(pil_images),
            "predictions_per_image": [len(p["labels"]) for p in preds],
        }
    )

# ------------------------------------------------------------------
# Pipeline
# ------------------------------------------------------------------

# {{docs-fragment pipeline}}
@cpu_env.task(report=True)
async def pipeline(
    model_name: str = "PekingU/rtdetr_v2_r18vd",
    dataset_repo: str = "sagecodes/union_flyte_swag_object_detection",
    annotations_path: str = "swag/train.json",
    images_subdir: str = "swag/images",
    epochs: int = 30,
    lr: float = 5e-5,
    batch_size: int = 4,
    val_fraction: float = 0.2,
    threshold: float = 0.5,
    demo_images: int = 8,
    eval_every_n_epochs: int | None = None,
) -> tuple[flyte.io.Dir, str]:
    """
    End-to-end RT-DETRv2 fine-tuning pipeline.

    Returns the fine-tuned model directory and a JSON summary.

    1. Download COCO dataset from Hugging Face and split train/val
    2. Fine-tune RT-DETRv2 on the train split
    3. Evaluate: COCO mAP of the fine-tuned model on the val split
    4. Inference demo: render bounding boxes on val images
    """
    log.info(f"Pipeline: {model_name} | dataset={dataset_repo}")

    def _pipeline_progress(step: int, label: str) -> str:
        steps = ["Preparing Data", "Fine-tuning", "Evaluating", "Inference Demo"]
        dots = ""
        for i, s in enumerate(steps):
            if i + 1 < step:
                icon = '<span style="color:#06d6a0;">&#10003;</span>'
            elif i + 1 == step:
                icon = '<span style="color:#e94560;">&#9679;</span>'
            else:
                icon = '<span style="color:#adb5bd;">&#9675;</span>'
            dots += f"<span style='margin:0 8px;'>{icon} {s}</span>"
        return f"""
        <h2>RT-DETRv2 Object Detection Pipeline</h2>
        <p><b>Model:</b> {model_name} | <b>Dataset:</b> {dataset_repo}</p>
        <div class="card" style="text-align:center;">{dots}</div>
        <p>{label}</p>
        """

    await flyte.report.replace.aio(
        _wrap_report(_pipeline_progress(1, "Downloading and splitting dataset...")),
        do_flush=True,
    )

    data_dir = await prepare_data(
        dataset_repo=dataset_repo,
        annotations_path=annotations_path,
        images_subdir=images_subdir,
        val_fraction=val_fraction,
    )

    await flyte.report.replace.aio(
        _wrap_report(_pipeline_progress(2, "Fine-tuning model...")),
        do_flush=True,
    )

    finetuned_dir = await train(
        model_name, data_dir, epochs, lr, batch_size,
        eval_every_n_epochs=eval_every_n_epochs,
    )

    await flyte.report.replace.aio(
        _wrap_report(_pipeline_progress(3, "Running COCO mAP evaluation...")),
        do_flush=True,
    )

    metrics_json = await evaluate(finetuned_dir, data_dir, threshold)
    metrics = json.loads(metrics_json)

    await flyte.report.replace.aio(
        _wrap_report(_pipeline_progress(4, "Rendering bounding box demo...")),
        do_flush=True,
    )

    demo_json = await inference_demo(
        finetuned_dir, data_dir, threshold, demo_images,
        metrics_json=metrics_json,
    )

    ft_map = metrics["finetuned"].get("map", 0)
    ft_map50 = metrics["finetuned"].get("map_50", 0)

    final_html = f"""
    <h2>Pipeline Complete</h2>
    <h3>{model_name}</h3>
    <div class="stat-grid">
      <div class="stat"><div class="value">{metrics['num_val_images']}</div><div class="label">Val Images</div></div>
      <div class="stat"><div class="value highlight">{ft_map:.3f}</div><div class="label">mAP</div></div>
      <div class="stat"><div class="value highlight">{ft_map50:.3f}</div><div class="label">mAP@50</div></div>
    </div>
    <div class="card">
      <b>Configuration:</b> {epochs} epochs | LR {lr} | Batch size {batch_size} |
      Val fraction {val_fraction} | Threshold {threshold}
    </div>
    """

    await flyte.report.replace.aio(_wrap_report(final_html), do_flush=True)

    log.info(f"Pipeline complete. Fine-tuned mAP: {ft_map:.3f}")
    return finetuned_dir, json.dumps({"metrics": metrics, "demo": json.loads(demo_json)})

# {{/docs-fragment pipeline}}

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(pipeline)
    print(run.url)
    run.wait()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/detr_object_detection/detr_object_detection.py*
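Stripped of the training details, the live-reporting pattern in `train` is: run the blocking loop in a worker thread with `asyncio.to_thread`, and have the synchronous callback hand coroutines back to the event loop with `asyncio.run_coroutine_threadsafe`. The minimal sketch below shows that bridge with the standard library only; `publish` is a hypothetical stand-in for `flyte.report.replace.aio`, and `train_loop` is a hypothetical stand-in for `trainer.train`.

```
import asyncio


async def publish(sink: list[str], html: str) -> None:
    """Hypothetical stand-in for flyte.report.replace.aio: records each update."""
    sink.append(html)


def train_loop(loop: asyncio.AbstractEventLoop, sink: list[str], steps: int) -> None:
    """Synchronous 'training' loop; runs in a worker thread via asyncio.to_thread."""
    for step in range(1, steps + 1):
        # Schedule the coroutine on the event loop from this non-async thread.
        future = asyncio.run_coroutine_threadsafe(publish(sink, f"step {step}/{steps}"), loop)
        # Block this worker thread until the update is applied
        # (the tutorial's callbacks do not wait here).
        future.result(timeout=5)


async def main(steps: int = 3) -> list[str]:
    loop = asyncio.get_running_loop()
    sink: list[str] = []
    # Run the blocking loop in a thread so the event loop stays free to
    # execute the coroutines that the thread schedules.
    await asyncio.to_thread(train_loop, loop, sink, steps)
    return sink


if __name__ == "__main__":
    print(asyncio.run(main()))
```

Here the worker thread waits on each `future.result()`, so every update lands, in order, before `to_thread` returns. The tutorial's callbacks skip that wait to avoid slowing training, which means an update may still be pending when `trainer.train` returns; the task then replaces the report with a final version afterwards.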

## Run the workflow

From the [example directory](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/detr_object_detection):

```
cd v2/tutorials/detr_object_detection
uv run --script detr_object_detection.py
```

Quick local smoke test with one epoch:

```
flyte run detr_object_detection.py pipeline --epochs 1 --batch_size 2
```

The **train**, **evaluate**, and **inference_demo** tasks request a GPU. Open their task reports for the training charts, mAP metrics, and annotated validation images.
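Both `evaluate` and `inference_demo` convert COCO `[x, y, width, height]` annotations to `[x0, y0, x1, y1]` corners before scoring, and the mAP@50 and mAP@75 metrics count a prediction as a match only when its intersection-over-union (IoU) with a ground-truth box of the same class reaches the threshold. The dependency-free sketch below illustrates those two steps; `coco_to_xyxy` and `iou` are illustrative names, and in the tutorial `torchmetrics.detection.mean_ap.MeanAveragePrecision` handles the full matching and averaging.

```
def coco_to_xyxy(bbox: list[float]) -> list[float]:
    """Convert a COCO [x, y, w, h] box to [x0, y0, x1, y1] corners."""
    x, y, w, h = bbox
    return [x, y, x + w, y + h]


def iou(a: list[float], b: list[float]) -> float:
    """Intersection-over-union of two [x0, y0, x1, y1] boxes."""
    # Overlap rectangle (zero area if the boxes do not intersect).
    ix0, iy0 = max(a[0], b[0]), max(a[1], b[1])
    ix1, iy1 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix1 - ix0) * max(0.0, iy1 - iy0)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


if __name__ == "__main__":
    gt = coco_to_xyxy([10, 10, 100, 100])  # [10, 10, 110, 110]
    pred = [20, 20, 120, 120]
    score = iou(gt, pred)
    # Overlap is 90 x 90 = 8100; union is 10000 + 10000 - 8100 = 11900.
    print(f"IoU = {score:.3f}")              # IoU = 0.681
    print("match at IoU 0.50:", score >= 0.5)   # True
    print("match at IoU 0.75:", score >= 0.75)  # False
```

A box at 0.68 IoU counts as a match at the 0.50 threshold but not at 0.75. The headline COCO mAP repeats this matching at IoU thresholds from 0.50 to 0.95 in steps of 0.05 and averages the results.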

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/agents ===

# Agents

Tutorials for building agentic workflows and autonomous LLM-powered systems.

### **Agents > Autoresearch agent**

Drive Claude Code through autonomous experiments in a GPU container, then commit the results and open a pull request.

### **Agents > Parallelized autoresearch agent**

Scale autoresearch with a code-mode MLE agent that batches train.py edits and runs sandbox experiments in parallel via flyte.map.

### **Agents > AutoSec researcher agent**

Fan out vulnerability analysis across C targets, hypothesize exploits with an LLM agent, and validate PoCs in an isolated sandbox.

### **Agents > Coding agent**

Securely execute and iterate on LLM-generated code using a code agent with error reflection and retry logic.

### **Agents > Competitive intelligence agent**

Fan out across competitors, extract source-cited market deltas with the You.com Search API, and build a knowledge-graph-ready intelligence table.

### **Agents > Compliance monitoring agent**

Monitor trusted regulatory sources with the You.com Research API and route citation-precise findings to the right team.

### **Agents > Deep research**

Build an agentic workflow for deep research with multi-step reasoning and evaluation.

### **Agents > LangGraph research agent**

Combine LangGraph control flow with Flyte tasks for multi-topic web research with quality-check loops.

### **Agents > Field data enrichment agent**

Enrich geo-tagged operational events with real-world public context using the You.com Search API with country and freshness targeting.

### **Agents > MLE bot: an autonomous ML engineer**

An autonomous ML agent that designs, runs, and iterates on experiments using Flyte's durable sandbox for safe LLM-generated code execution.

### **Agents > Support resolution agent**

Ground support tickets in fresh public sources via the You.com Research API and draft cited, customer-ready replies for human review.

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/agents/autoresearch ===

# Autoresearch agent

> [!NOTE]
> Code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/autoresearch).

This tutorial wraps an autonomous AI research loop in a single Flyte task. The task spins up a GPU container, installs the [Claude Code](https://docs.anthropic.com/en/docs/claude-code/overview) CLI, clones a research repository, and points Claude Code at a `program.md` brief. The agent runs experiments to improve a model, writes results to disk, and the task then commits the changes and opens a pull request — with a progress plot rendered both in the PR and in the Flyte UI.

It's an example of using Flyte as durable infrastructure for long-running, autonomous agent work:

- **A GPU `TaskEnvironment`** with the API-key and GitHub secrets the agent needs.
- **`report=True`** to stream a progress plot into the Flyte UI.
- **A reconnecting `run.wait()`** loop in the driver so a dropped client connection doesn't lose track of a multi-hour run.
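A reconnecting wait loop is a retry wrapper around a blocking call. The sketch below shows the idea generically with exponential backoff. It is not the tutorial's driver code: `wait_fn` stands in for `run.wait`, and treating `ConnectionError` as the retryable failure is an assumption, so check which exceptions your Flyte client actually raises when a connection drops.

```
import time
from typing import Callable


def wait_with_reconnect(
    wait_fn: Callable[[], None],
    max_attempts: int = 5,
    base_delay: float = 1.0,
    retry_on: tuple[type[BaseException], ...] = (ConnectionError,),
) -> int:
    """Call wait_fn until it returns, retrying on transient connection errors.

    Returns the number of attempts used; re-raises the last error once
    max_attempts is exhausted. Errors not listed in retry_on propagate at once.
    """
    if max_attempts < 1:
        raise ValueError("max_attempts must be >= 1")
    attempt = 0
    while True:
        attempt += 1
        try:
            wait_fn()
            return attempt
        except retry_on as exc:
            if attempt >= max_attempts:
                raise
            delay = base_delay * 2 ** (attempt - 1)  # 1x, 2x, 4x, ... backoff
            print(f"wait interrupted ({exc!r}); reconnecting in {delay:.2f}s")
            time.sleep(delay)


if __name__ == "__main__":
    # Simulated client that drops the connection twice before the run finishes.
    drops = iter([ConnectionError("reset"), ConnectionError("reset")])

    def flaky_wait() -> None:
        err = next(drops, None)
        if err is not None:
            raise err

    print(wait_with_reconnect(flaky_wait, base_delay=0.01))  # prints 3
```

Because the run itself keeps executing on the cluster, only the client's wait needs retrying; a multi-hour job is unaffected by the dropped connection.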

> [!WARNING]
> This example drives a coding agent that executes arbitrary code and pushes commits to a GitHub repository. Run it against a repository you control, and review the constants described below before launching.

## Define the container image

The image is kept in its own `_image.py` module so edits to the agent logic in `run.py` don't invalidate the image cache. Node.js and the Claude Code CLI are installed at run time (see below) to keep the image small.

```
# /// script
# requires-python = ">=3.11"
# dependencies = [
#     "flyte>=2.0.0b22",
#     "PyGithub>=2.5.0",
#     "matplotlib>=3.7.0",
#     "pandas>=2.0.0",
# ]
# ///
#
# Stable image definition — kept separate from run.py so edits to run.py
# don't invalidate the image cache. Only touch this file when the image itself needs to change.

import flyte

# {{docs-fragment image}}
image = (
    flyte.Image.from_uv_script(__file__, name="autoresearch-agent", pre=True)
    .with_apt_packages("git")
)
# {{/docs-fragment image}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/autoresearch/_image.py*

## Define the task environment

The task needs a GPU, a generous disk for the cloned repo and model weights, and two secrets: a GitHub token (to clone and push) and an Anthropic API key (for Claude Code).

```
# /// script
# requires-python = ">=3.11"
# dependencies = [
#     "flyte>=2.0.0b22",
#     "PyGithub>=2.5.0",
#     "matplotlib>=3.7.0",
# ]
# ///

"""
AutoResearch Agent - Runs the autoresearch workflow using Claude Code CLI in a GPU environment.

This agent:
1. Starts a GPU-enabled container
2. Installs Claude Code CLI
3. Clones the autoresearch repository
4. Points Claude Code at program.md as the prompt and lets it run
5. Commits the result (CSV + code changes in train/) and creates a PR
"""

import os
import shlex
import subprocess
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

from github import Auth, Github

import flyte
import flyte.report
from _image import image as autoresearch_image

# Set these to your own GitHub identity and a repository you control before running.
GITHUB_USERNAME = "parnianz"
GITHUB_EMAIL = "parnianzargham@gmail.com"
AUTORESEARCH_REPO_URL = "https://github.com/unionai-oss/autoresearch.git"
AUTORESEARCH_REPO_FULL_NAME = "unionai-oss/autoresearch"

# {{docs-fragment env}}
autoresearch_env = flyte.TaskEnvironment(
    name="autoresearch-agent",
    resources=flyte.Resources(
        cpu=8,
        memory="32Gi",
        gpu="T4:1",
        disk="100Gi",
    ),
    secrets=[
        flyte.Secret(key="github_token", as_env_var="GITHUB_TOKEN"),
        flyte.Secret(key="internal-anthropic-api-key", as_env_var="ANTHROPIC_API_KEY"),
    ],
    image=autoresearch_image,
)
# {{/docs-fragment env}}

# {{docs-fragment result}}
@dataclass
class AutoResearchResult:
    """Result of the autoresearch run."""

    pr_url: str
    pr_number: int
    branch_name: str
    files_changed: list[str]
    success: bool
    error_message: Optional[str] = None
# {{/docs-fragment result}}

def clone_repository(repo_url: str, work_dir: Path, github_token: str) -> Path:
    """Clone the autoresearch repository with authentication."""
    repo_name = repo_url.rstrip("/").split("/")[-1].replace(".git", "")
    repo_path = work_dir / repo_name

    # Inject token into HTTPS URL for authentication
    authenticated_url = repo_url.replace(
        "https://", f"https://{GITHUB_USERNAME}:{github_token}@"
    )

    if repo_path.exists():
        subprocess.run(["git", "pull"], cwd=repo_path, check=True)
    else:
        subprocess.run(["git", "clone", authenticated_url, str(repo_path)], check=True)

    return repo_path

# {{docs-fragment task}}
@autoresearch_env.task(report=True)
async def run_autoresearch() -> AutoResearchResult:
    """
    Run the autoresearch workflow end-to-end.

    Steps:
    - Clone https://github.com/unionai-oss/autoresearch
    - Configure git identity
    - Create a new branch
    - Run Claude Code CLI with program.md as the prompt
    - Commit results (CSV + train/ changes)
    - Push and open a PR against the autoresearch repo
    """
    github_token = os.environ["GITHUB_TOKEN"]
    anthropic_api_key = os.environ["ANTHROPIC_API_KEY"]

    # --- Install Node.js + Claude Code at runtime (keeps image small and submission fast) ---
    import tarfile
    import urllib.request as _urllib

    subprocess.run(["apt-get", "update", "-y"], check=False)
    subprocess.run(["apt-get", "install", "-y", "git"], check=False)

    node_url = "https://nodejs.org/dist/v20.19.0/node-v20.19.0-linux-x64.tar.gz"
    node_tar = Path("/tmp/node.tar.gz")
    print(f"Downloading Node.js from {node_url}...", flush=True)
    _urllib.urlretrieve(node_url, node_tar)
    size_mb = node_tar.stat().st_size / 1024 / 1024
    print(f"Downloaded {size_mb:.1f} MB to {node_tar}", flush=True)
    if size_mb < 1:
        raise RuntimeError(f"Node.js download appears empty/corrupt ({size_mb:.2f} MB) — network may be restricted")
    node_dir = Path("/tmp/node")
    node_dir.mkdir(exist_ok=True)
    print("Extracting Node.js...", flush=True)
    with tarfile.open(node_tar, "r:gz") as tar:
        # Drop the top-level "node-v20.19.0-linux-x64/" directory so files land
        # directly under /tmp/node (e.g. /tmp/node/bin/node)
        members = []
        for m in tar.getmembers():
            _, sep, rest = m.name.partition("/")
            if sep and rest:
                m.name = rest
                members.append(m)
        tar.extractall(str(node_dir), members=members)

    # Add node/npm to PATH for this process and all subprocesses
    node_bin = str(node_dir / "bin")
    os.environ["PATH"] = node_bin + ":" + os.environ.get("PATH", "")
    print(f"Node version: {subprocess.run(['node', '--version'], capture_output=True, text=True).stdout.strip()}", flush=True)

    npm_prefix = "/tmp/npm-global"
    Path(npm_prefix).mkdir(exist_ok=True)
    subprocess.run(["npm", "install", "-g", "--prefix", npm_prefix, "@anthropic-ai/claude-code"], check=True)
    os.environ["PATH"] = str(Path(npm_prefix) / "bin") + ":" + os.environ["PATH"]
    print("Node.js + Claude Code installed.", flush=True)

    # --- Clone repo ---
    work_dir = Path("/tmp/autoresearch_workspace")
    work_dir.mkdir(exist_ok=True, parents=True)
    repo_path = clone_repository(AUTORESEARCH_REPO_URL, work_dir, github_token)

    # --- Git identity ---
    subprocess.run(
        ["git", "config", "--global", "user.email", GITHUB_EMAIL], check=True
    )
    subprocess.run(
        ["git", "config", "--global", "user.name", GITHUB_USERNAME], check=True
    )

    # --- Create branch ---
    import time as _time
    branch_name = f"autoresearch/claude-run-{int(_time.time())}"
    try:
        subprocess.run(
            ["git", "checkout", "-b", branch_name],
            cwd=repo_path,
            check=True,
        )
    except subprocess.CalledProcessError:
        subprocess.run(
            ["git", "checkout", branch_name],
            cwd=repo_path,
            check=True,
        )

    # --- Read program.md to use as the Claude Code prompt ---
    program_md = repo_path / "program.md"
    if not program_md.exists():
        raise FileNotFoundError(
            f"program.md not found in {repo_path}. "
            "Make sure the autoresearch repo has a program.md at its root."
        )

    program_md_content = program_md.read_text()
    print(f"Loaded prompt from program.md ({len(program_md_content)} chars)")
    # {{/docs-fragment task}}

    # Install repo dependencies before handing off to Claude
    for pip_cmd in [
        ["pip", "install", "-e", "."],
        ["pip", "install", "-r", "requirements.txt"],
    ]:
        req_file = repo_path / pip_cmd[-1] if pip_cmd[-1].startswith("req") else None
        if req_file is None or req_file.exists():
            dep_result = subprocess.run(
                pip_cmd, cwd=repo_path, capture_output=True, text=True
            )
            print(f"{' '.join(pip_cmd)}:\n{dep_result.stdout}", flush=True)
            if dep_result.returncode != 0:
                print(f"(non-fatal) {dep_result.stderr}", flush=True)

    # Wrap the program.md content with explicit instructions to write outputs to disk
    prompt = f"""You are running inside an automated GPU pipeline. You MUST write all outputs to disk as actual files.

Here are your instructions from program.md:

{program_md_content}

LOGGING INSTRUCTIONS (follow exactly):
- Before you start any training, print this exact line: [AUTORESEARCH] Training started
- Before training, print what change you are testing: [AUTORESEARCH] Change: <one line description of the code change being tested>
- When training finishes, print this exact line: [AUTORESEARCH] Training finished
- After training, print the key metric value: [AUTORESEARCH] Metric: <metric name>=<value>
- When writing results to CSV, print this exact line: [AUTORESEARCH] Writing results to CSV

IMPORTANT: After completing the above instructions, make sure you have:
1. Written the final results to a CSV file in this repository (e.g. results/results.csv or similar)
2. Saved all code changes you made to the train/ directory (or wherever the training code lives)
3. All files must be written to the current working directory so they appear in git status
If any command fails, debug and fix it rather than stopping. Do not just print results — write them to files on disk."""

    # --- Pre-flight: verify the claude CLI is installed (this does not validate the API key) ---
    version_check = subprocess.run(
        ["claude", "--version"], capture_output=True, text=True
    )
    print(f"claude version: {version_check.stdout.strip()} | stderr: {version_check.stderr.strip()}", flush=True)
    if version_check.returncode != 0:
        raise RuntimeError(f"claude CLI not found or broken: {version_check.stderr}")

    # --- Disable Claude Code sandbox ---
    # In Kubernetes/Flyte pods, Claude Code's sandbox tries to spin up a nested container
    # which fails silently and causes file writes to go to an ephemeral space instead of
    # the real working directory. Disabling it makes writes land in the actual filesystem.
    claude_config_dir = Path("/root/.claude")
    claude_config_dir.mkdir(parents=True, exist_ok=True)
    settings = claude_config_dir / "settings.json"
    import json as _json
    existing = _json.loads(settings.read_text()) if settings.exists() else {}
    existing["sandbox"] = False
    settings.write_text(_json.dumps(existing, indent=2))
    print(f"Wrote Claude Code settings: {settings.read_text()}", flush=True)

    # --- Run Claude Code CLI ---
    # Pass the prompt as a positional argument; CI=true (set in claude_env below) enables non-interactive mode
    cmd = [
        "claude",
        "--dangerously-skip-permissions",
        "--max-turns", "100",
        "--model", "claude-haiku-4-5-20251001",
        prompt,
    ]

    # Log every flag but elide the (long) prompt
    print(f"Running: {shlex.join(cmd[:-1])} <prompt>", flush=True)

    claude_env = {
        **os.environ,
        "ANTHROPIC_API_KEY": anthropic_api_key,
        "CLAUDE_SKIP_PERMISSIONS": "true",
        "CI": "true",  # Enables non-interactive mode (no TTY required)
    }

    # Stream output line by line so logs appear in real time instead of buffering until done
    proc = subprocess.Popen(
        cmd,
        cwd=repo_path,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # merge stderr into stdout stream
        text=True,
        env=claude_env,
    )

    stdout_lines = []
    for line in proc.stdout:
        line = line.rstrip("\n")
        print(line, flush=True)
        stdout_lines.append(line)

    proc.wait()
    full_output = "\n".join(stdout_lines)
    print(f"Claude Code exit code: {proc.returncode}", flush=True)

    if proc.returncode != 0:
        raise RuntimeError(
            f"Claude Code CLI exited with code {proc.returncode}\n"
            f"output: {full_output[-2000:]}"
        )

    # --- Collect changed files ---
    git_status = subprocess.run(
        ["git", "status", "--porcelain"],
        cwd=repo_path,
        capture_output=True,
        text=True,
        check=True,
    )

    print(f"Git status:\n{git_status.stdout}", flush=True)

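    files_changed = []
    # Don't strip() the whole output: porcelain lines can start with a space
    # (e.g. " M train.py"), and stripping the first line would shift its path.
    for line in git_status.stdout.splitlines():
        if line:
            # git status --porcelain: "XY PATH", where X and Y are status flags
            file_path = line[3:].strip()
            files_changed.append(file_path)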

    # Also list all files in repo dir for debugging
    all_files = subprocess.run(
        ["find", ".", "-type", "f", "-not", "-path", "./.git/*"],
        cwd=repo_path,
        capture_output=True,
        text=True,
    )
    print(f"All files in repo:\n{all_files.stdout}", flush=True)

    if not files_changed:
        raise RuntimeError(
            "Claude Code ran successfully but produced no file changes.\n"
            f"output: {full_output[-2000:]}"
        )

    # --- Commit ---
    subprocess.run(["git", "add", "."], cwd=repo_path, check=True)
    subprocess.run(["git", "add", "-f", "results.tsv"], cwd=repo_path, check=False)
    subprocess.run(["git", "add", "-f", "results/"], cwd=repo_path, check=False)
    commit_message = (
        "feat: autoresearch run via Claude Code\n\n"
        "Added research results (CSV) and updated train/ code changes.\n"
        "Generated by the autoresearch Flyte agent."
    )
    subprocess.run(
        ["git", "commit", "-m", commit_message],
        cwd=repo_path,
        check=True,
    )

    # --- Push ---
    print(f"GitHub token present: {bool(github_token)}, length: {len(github_token) if github_token else 0}", flush=True)
    authenticated_url = AUTORESEARCH_REPO_URL.replace(
        "https://", f"https://{GITHUB_USERNAME}:{github_token}@"
    )
    subprocess.run(
        ["git", "remote", "set-url", "origin", authenticated_url],
        cwd=repo_path,
        check=True,
    )
    push_result = subprocess.run(
        ["git", "push", "-u", "origin", branch_name, "--force"],
        cwd=repo_path,
        capture_output=True,
        text=True,
    )
    print(f"Push stdout: {push_result.stdout}", flush=True)
    print(f"Push stderr: {push_result.stderr}", flush=True)
    if push_result.returncode != 0:
        raise RuntimeError(f"git push failed (exit {push_result.returncode}):\n{push_result.stderr}")

    # --- Create PR via PyGithub ---
    auth = Auth.Token(github_token)
    gh = Github(auth=auth)
    repo = gh.get_repo(AUTORESEARCH_REPO_FULL_NAME)

    csv_files = [f for f in files_changed if f.endswith(".csv")]
    train_files = [f for f in files_changed if "train" in f]

    pr_body = f"""## AutoResearch Run

This PR was automatically generated by the autoresearch Flyte agent using Claude Code CLI.

### What changed
- **Result CSV files**: {', '.join(f'`{f}`' for f in csv_files) or 'none detected'}
- **Train code changes**: {', '.join(f'`{f}`' for f in train_files) or 'none detected'}

### All changed files
{chr(10).join(f'- `{f}`' for f in files_changed)}

---
🤖 Generated by [autoresearch Flyte agent](https://github.com/unionai-oss/autoresearch)
"""

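    # get_pulls(head=...) expects "owner:branch", so derive the owner from the configured repo
    repo_owner = AUTORESEARCH_REPO_FULL_NAME.split("/")[0]
    existing_prs = list(repo.get_pulls(state="open", head=f"{repo_owner}:{branch_name}"))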
    if existing_prs:
        pr = existing_prs[0]
        print(f"PR already exists: {pr.html_url}", flush=True)
    else:
        pr = repo.create_pull(
            title="feat: autoresearch results + train changes",
            body=pr_body,
            head=branch_name,
            base="master",
        )
        print(f"PR created: {pr.html_url}", flush=True)

    # --- Generate progress plot from results.tsv ---
    plot_path = repo_path / "progress.png"
    results_tsv = repo_path / "results.tsv"
    if results_tsv.exists():
        import matplotlib
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
        import pandas as pd

        df = pd.read_csv(str(results_tsv), sep="\t")
        df["val_bpb"] = pd.to_numeric(df["val_bpb"], errors="coerce")
        df["memory_gb"] = pd.to_numeric(df["memory_gb"], errors="coerce")
        df["status"] = df["status"].str.strip().str.upper()

        # Filter out crashes for plotting
        valid = df[df["status"] != "CRASH"].copy()
        valid = valid.reset_index(drop=True)

        if len(valid) > 0 and valid["val_bpb"].notna().any():
            baseline_bpb = valid.loc[0, "val_bpb"]
            best = valid["val_bpb"].min()

            # Only plot points at or below baseline (the interesting region)
            below = valid[valid["val_bpb"] <= baseline_bpb + 0.0005]

            fig, ax = plt.subplots(figsize=(16, 8))

            # Plot discarded as faint background dots
            disc = below[below["status"] == "DISCARD"]
            ax.scatter(disc.index, disc["val_bpb"],
                       c="#cccccc", s=12, alpha=0.5, zorder=2, label="Discarded")

            # Plot kept experiments as prominent green dots
            kept_v = below[below["status"] == "KEEP"]
            ax.scatter(kept_v.index, kept_v["val_bpb"],
                       c="#2ecc71", s=50, zorder=4, label="Kept", edgecolors="black", linewidths=0.5)

            # Running minimum step line
            kept_mask = valid["status"] == "KEEP"
            kept_idx = valid.index[kept_mask]
            kept_bpb = valid.loc[kept_mask, "val_bpb"]
            running_min = kept_bpb.cummin()
            ax.step(kept_idx, running_min, where="post", color="#27ae60",
                    linewidth=2, alpha=0.7, zorder=3, label="Running best")

            # Label each kept experiment with its description
            for idx, bpb in zip(kept_idx, kept_bpb):
                desc = str(valid.loc[idx, "description"]).strip()
                if len(desc) > 45:
                    desc = desc[:42] + "..."
                ax.annotate(desc, (idx, bpb),
                            textcoords="offset points",
                            xytext=(6, 6), fontsize=8.0,
                            color="#1a7a3a", alpha=0.9,
                            rotation=30, ha="left", va="bottom")

            n_total = len(df)
            n_kept = len(df[df["status"] == "KEEP"])
            ax.set_xlabel("Experiment #", fontsize=12)
            ax.set_ylabel("Validation BPB (lower is better)", fontsize=12)
            ax.set_title(f"Autoresearch Progress: {n_total} Experiments, {n_kept} Kept Improvements", fontsize=14)
            ax.legend(loc="upper right", fontsize=9)
            ax.grid(True, alpha=0.2)

            margin = (baseline_bpb - best) * 0.15
            ax.set_ylim(best - margin, baseline_bpb + margin)

            plt.tight_layout()
            plt.savefig(str(plot_path), dpi=150, bbox_inches="tight")
            plt.close(fig)
            print(f"Saved plot to {plot_path}", flush=True)

            # Force-add plot to git, amend the commit, and push so the image is on the branch
            subprocess.run(["git", "add", "-f", str(plot_path)], cwd=repo_path, check=False)
            subprocess.run(
                ["git", "commit", "--amend", "--no-edit"],
                cwd=repo_path, check=False,
            )
            subprocess.run(
                ["git", "push", "-u", "origin", branch_name, "--force"],
                cwd=repo_path, check=False,
            )

            # Link the committed image from a PR comment. GitHub does not render
            # data: URIs in Markdown, and an inline base64 PNG can exceed the
            # 65,536-character comment limit.
            plot_url = (
                f"https://github.com/{AUTORESEARCH_REPO_FULL_NAME}"
                f"/blob/{branch_name}/progress.png?raw=true"
            )
            pr.create_issue_comment(
                "## Autoresearch Progress\n\n"
                f"![Autoresearch Progress]({plot_url})"
            )
            print("Posted plot as PR comment.", flush=True)

            # The Flyte report renders HTML, so the inline base64 image is kept there
            import base64
            img_b64 = base64.b64encode(plot_path.read_bytes()).decode()

            # Show plot in Flyte UI via report
            await flyte.report.replace.aio(
                f"<h2>Autoresearch Progress</h2>"
                f'<img src="data:image/png;base64,{img_b64}" style="max-width:100%"/>'
                f'<p><a href="{pr.html_url}">View PR</a></p>'
            )
            await flyte.report.flush.aio()
        else:
            print("results.tsv found but no valid val_bpb rows — skipping plot.", flush=True)
    else:
        print("results.tsv not found — skipping plot.", flush=True)

    return AutoResearchResult(
        pr_url=pr.html_url,
        pr_number=pr.number,
        branch_name=branch_name,
        files_changed=files_changed,
        success=True,
    )

# {{docs-fragment main}}
if __name__ == "__main__":
    import time

    flyte.init_from_config()

    run = flyte.with_runcontext(mode="remote").run(run_autoresearch)

    print(f"AutoResearch run started: {run.url}")
    print("Waiting for completion...")

    while True:
        try:
            run.wait()
            break
        except Exception as e:
            print(f"Connection dropped ({e}), reconnecting in 30s...")
            time.sleep(30)

    print(f"Done! See run at: {run.url}")
# {{/docs-fragment main}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/autoresearch/run.py*

The agent targets a specific repository and Git identity through module-level constants. Update these to point at your own fork before running, and make sure the GitHub token can push to that repository:

```
GITHUB_USERNAME = "<YOUR_GITHUB_USERNAME>"
GITHUB_EMAIL = "you@example.com"
AUTORESEARCH_REPO_URL = "https://github.com/<YOUR_ORG>/<YOUR_REPO>.git"
AUTORESEARCH_REPO_FULL_NAME = "<YOUR_ORG>/<YOUR_REPO>"
```
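
`AUTORESEARCH_REPO_URL` and `AUTORESEARCH_REPO_FULL_NAME` must name the same repository. GitHub's pull-request lookup by `head` also expects the `owner:branch` form, so the owner matters too. One way to keep these values from drifting apart is to derive the full name from the URL. The helper below is a hypothetical sketch that only accepts `https://github.com/<owner>/<repo>` URLs, with or without a `.git` suffix.

```python
from urllib.parse import urlparse


def repo_full_name_from_url(repo_url: str) -> str:
    """Turn 'https://github.com/<owner>/<repo>.git' into '<owner>/<repo>' (hypothetical helper)."""
    parsed = urlparse(repo_url)
    if parsed.scheme != "https" or parsed.hostname != "github.com":
        raise ValueError(f"Expected an https://github.com URL, got {repo_url!r}")
    parts = parsed.path.strip("/").split("/")
    if len(parts) != 2:
        raise ValueError(f"Expected /<owner>/<repo> in the URL path, got {parsed.path!r}")
    owner, repo = parts
    return f"{owner}/{repo.removesuffix('.git')}"


AUTORESEARCH_REPO_URL = "https://github.com/<YOUR_ORG>/<YOUR_REPO>.git"
AUTORESEARCH_REPO_FULL_NAME = repo_full_name_from_url(AUTORESEARCH_REPO_URL)
# The owner is what a PR lookup needs in f"{REPO_OWNER}:{branch_name}"
REPO_OWNER = AUTORESEARCH_REPO_FULL_NAME.split("/")[0]
```

SSH remotes such as `git@github.com:org/repo.git` are rejected on purpose, because the task authenticates over HTTPS with a token.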

## Model the result

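The task returns a typed `AutoResearchResult` dataclass describing the pull request it created: the PR URL and number, the branch name, and the list of changed files. In the current code `success` is always `True`; failures raise exceptions rather than populating `error_message`.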
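
Because the result is a plain dataclass, a caller can format it into a short summary once a run finishes. The sketch below copies the `AutoResearchResult` fields from `run.py` so it runs on its own. `summarize_result` is a hypothetical helper, and the field values are placeholders. How you fetch the result object from a finished remote run depends on the Flyte API and is not shown. The `success=False` branch is defensive, since `run_autoresearch` raises on failure.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class AutoResearchResult:
    """Same fields as the dataclass in run.py."""

    pr_url: str
    pr_number: int
    branch_name: str
    files_changed: list[str]
    success: bool
    error_message: Optional[str] = None


def summarize_result(result: AutoResearchResult) -> str:
    """Format a short, multi-line summary of an autoresearch run (hypothetical helper)."""
    if not result.success:
        return f"Run failed on {result.branch_name}: {result.error_message or 'unknown error'}"
    lines = [
        f"PR #{result.pr_number}: {result.pr_url}",
        f"Branch: {result.branch_name}",
        f"Changed files ({len(result.files_changed)}):",
    ]
    lines += [f"  - {path}" for path in result.files_changed]
    return "\n".join(lines)


example = AutoResearchResult(
    pr_url="https://github.com/<YOUR_ORG>/<YOUR_REPO>/pull/1",  # placeholder values
    pr_number=1,
    branch_name="autoresearch/claude-run-1700000000",
    files_changed=["train.py"],
    success=True,
)
print(summarize_result(example))
```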

```
# /// script
# requires-python = ">=3.11"
# dependencies = [
#     "flyte>=2.0.0b22",
#     "PyGithub>=2.5.0",
#     "matplotlib>=3.7.0",
# ]
# ///

"""
AutoResearch Agent - Runs the autoresearch workflow using Claude Code CLI in a GPU environment.

This agent:
1. Starts a GPU-enabled container
2. Installs Claude Code CLI
3. Clones the autoresearch repository
4. Points Claude Code at program.md as the prompt and lets it run
5. Commits the result (CSV + code changes in train/) and creates a PR
"""

import os
import shlex
import subprocess
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

from github import Auth, Github

import flyte
import flyte.report
from _image import image as autoresearch_image

GITHUB_USERNAME = "parnianz"
GITHUB_EMAIL = "parnianzargham@gmail.com"
AUTORESEARCH_REPO_URL = "https://github.com/unionai-oss/autoresearch.git"
AUTORESEARCH_REPO_FULL_NAME = "unionai-oss/autoresearch"

# {{docs-fragment env}}
autoresearch_env = flyte.TaskEnvironment(
    name="autoresearch-agent",
    resources=flyte.Resources(
        cpu=8,
        memory="32Gi",
        gpu="T4:1",
        disk="100Gi",
    ),
    secrets=[
        flyte.Secret(key="github_token", as_env_var="GITHUB_TOKEN"),
        flyte.Secret(key="internal-anthropic-api-key", as_env_var="ANTHROPIC_API_KEY"),
    ],
    image=autoresearch_image,
)
# {{/docs-fragment env}}

# {{docs-fragment result}}
@dataclass
class AutoResearchResult:
    """Result of the autoresearch run."""

    pr_url: str
    pr_number: int
    branch_name: str
    files_changed: list[str]
    success: bool
    error_message: Optional[str] = None
# {{/docs-fragment result}}

def clone_repository(repo_url: str, work_dir: Path, github_token: str) -> Path:
    """Clone the autoresearch repository with authentication."""
    repo_name = repo_url.rstrip("/").split("/")[-1].replace(".git", "")
    repo_path = work_dir / repo_name

    # Inject token into HTTPS URL for authentication
    authenticated_url = repo_url.replace(
        "https://", f"https://{GITHUB_USERNAME}:{github_token}@"
    )

    if repo_path.exists():
        subprocess.run(["git", "pull"], cwd=repo_path, check=True)
    else:
        subprocess.run(["git", "clone", authenticated_url, str(repo_path)], check=True)

    return repo_path

# {{docs-fragment task}}
@autoresearch_env.task(report=True)
async def run_autoresearch() -> AutoResearchResult:
    """
    Run the autoresearch workflow end-to-end.

    Steps:
    - Clone https://github.com/unionai-oss/autoresearch
    - Configure git identity
    - Create a new branch
    - Run Claude Code CLI with program.md as the prompt
    - Commit results (CSV + train/ changes)
    - Push and open a PR against the autoresearch repo
    """
    github_token = os.environ["GITHUB_TOKEN"]
    anthropic_api_key = os.environ["ANTHROPIC_API_KEY"]

    # --- Install Node.js + Claude Code at runtime (keeps image small and submission fast) ---
    import tarfile
    import urllib.request as _urllib

    subprocess.run(["apt-get", "update", "-y"], check=False)
    subprocess.run(["apt-get", "install", "-y", "git"], check=False)

    node_url = "https://nodejs.org/dist/v20.19.0/node-v20.19.0-linux-x64.tar.gz"
    node_tar = Path("/tmp/node.tar.gz")
    print(f"Downloading Node.js from {node_url}...", flush=True)
    _urllib.urlretrieve(node_url, node_tar)
    size_mb = node_tar.stat().st_size / 1024 / 1024
    print(f"Downloaded {size_mb:.1f} MB to {node_tar}", flush=True)
    if size_mb < 1:
        raise RuntimeError(f"Node.js download appears empty/corrupt ({size_mb:.2f} MB) — network may be restricted")
    node_dir = Path("/tmp/node")
    node_dir.mkdir(exist_ok=True)
    print("Extracting Node.js...", flush=True)
    with tarfile.open(node_tar, "r:gz") as tar:
        # Drop the top-level "node-v20.19.0-linux-x64/" directory so files land
        # directly under /tmp/node (e.g. /tmp/node/bin/node)
        members = []
        for m in tar.getmembers():
            _, sep, rest = m.name.partition("/")
            if sep and rest:
                m.name = rest
                members.append(m)
        tar.extractall(str(node_dir), members=members)

    # Add node/npm to PATH for this process and all subprocesses
    node_bin = str(node_dir / "bin")
    os.environ["PATH"] = node_bin + ":" + os.environ.get("PATH", "")
    print(f"Node version: {subprocess.run(['node', '--version'], capture_output=True, text=True).stdout.strip()}", flush=True)

    npm_prefix = "/tmp/npm-global"
    Path(npm_prefix).mkdir(exist_ok=True)
    subprocess.run(["npm", "install", "-g", "--prefix", npm_prefix, "@anthropic-ai/claude-code"], check=True)
    os.environ["PATH"] = str(Path(npm_prefix) / "bin") + ":" + os.environ["PATH"]
    print("Node.js + Claude Code installed.", flush=True)

    # --- Clone repo ---
    work_dir = Path("/tmp/autoresearch_workspace")
    work_dir.mkdir(exist_ok=True, parents=True)
    repo_path = clone_repository(AUTORESEARCH_REPO_URL, work_dir, github_token)

    # --- Git identity ---
    subprocess.run(
        ["git", "config", "--global", "user.email", GITHUB_EMAIL], check=True
    )
    subprocess.run(
        ["git", "config", "--global", "user.name", GITHUB_USERNAME], check=True
    )

    # --- Create branch ---
    import time as _time
    branch_name = f"autoresearch/claude-run-{int(_time.time())}"
    try:
        subprocess.run(
            ["git", "checkout", "-b", branch_name],
            cwd=repo_path,
            check=True,
        )
    except subprocess.CalledProcessError:
        subprocess.run(
            ["git", "checkout", branch_name],
            cwd=repo_path,
            check=True,
        )

    # --- Read program.md to use as the Claude Code prompt ---
    program_md = repo_path / "program.md"
    if not program_md.exists():
        raise FileNotFoundError(
            f"program.md not found in {repo_path}. "
            "Make sure the autoresearch repo has a program.md at its root."
        )

    program_md_content = program_md.read_text()
    print(f"Loaded prompt from program.md ({len(program_md_content)} chars)")
    # {{/docs-fragment task}}

    # Install repo dependencies before handing off to Claude
    for pip_cmd in [
        ["pip", "install", "-e", "."],
        ["pip", "install", "-r", "requirements.txt"],
    ]:
        req_file = repo_path / pip_cmd[-1] if pip_cmd[-1].startswith("req") else None
        if req_file is None or req_file.exists():
            dep_result = subprocess.run(
                pip_cmd, cwd=repo_path, capture_output=True, text=True
            )
            print(f"{' '.join(pip_cmd)}:\n{dep_result.stdout}", flush=True)
            if dep_result.returncode != 0:
                print(f"(non-fatal) {dep_result.stderr}", flush=True)

    # Wrap the program.md content with explicit instructions to write outputs to disk
    prompt = f"""You are running inside an automated GPU pipeline. You MUST write all outputs to disk as actual files.

Here are your instructions from program.md:

{program_md_content}

LOGGING INSTRUCTIONS (follow exactly):
- Before you start any training, print this exact line: [AUTORESEARCH] Training started
- Before training, print what change you are testing: [AUTORESEARCH] Change: <one line description of the code change being tested>
- When training finishes, print this exact line: [AUTORESEARCH] Training finished
- After training, print the key metric value: [AUTORESEARCH] Metric: <metric name>=<value>
- When writing results to CSV, print this exact line: [AUTORESEARCH] Writing results to CSV

IMPORTANT: After completing the above instructions, make sure you have:
1. Written the final results to a CSV file in this repository (e.g. results/results.csv or similar)
2. Saved all code changes you made to the train/ directory (or wherever the training code lives)
3. All files must be written to the current working directory so they appear in git status
If any command fails, debug and fix it rather than stopping. Do not just print results — write them to files on disk."""

    # --- Pre-flight: verify the claude CLI is installed (this does not validate the API key) ---
    version_check = subprocess.run(
        ["claude", "--version"], capture_output=True, text=True
    )
    print(f"claude version: {version_check.stdout.strip()} | stderr: {version_check.stderr.strip()}", flush=True)
    if version_check.returncode != 0:
        raise RuntimeError(f"claude CLI not found or broken: {version_check.stderr}")

    # --- Disable Claude Code sandbox ---
    # In Kubernetes/Flyte pods, Claude Code's sandbox tries to spin up a nested container
    # which fails silently and causes file writes to go to an ephemeral space instead of
    # the real working directory. Disabling it makes writes land in the actual filesystem.
    claude_config_dir = Path("/root/.claude")
    claude_config_dir.mkdir(parents=True, exist_ok=True)
    settings = claude_config_dir / "settings.json"
    import json as _json
    existing = _json.loads(settings.read_text()) if settings.exists() else {}
    existing["sandbox"] = False
    settings.write_text(_json.dumps(existing, indent=2))
    print(f"Wrote Claude Code settings: {settings.read_text()}", flush=True)

    # --- Run Claude Code CLI ---
    # Pass the prompt as a positional argument; CI=true (set in claude_env below) enables non-interactive mode
    cmd = [
        "claude",
        "--dangerously-skip-permissions",
        "--max-turns", "100",
        "--model", "claude-haiku-4-5-20251001",
        prompt,
    ]

    # Log every flag but elide the (long) prompt
    print(f"Running: {shlex.join(cmd[:-1])} <prompt>", flush=True)

    claude_env = {
        **os.environ,
        "ANTHROPIC_API_KEY": anthropic_api_key,
        "CLAUDE_SKIP_PERMISSIONS": "true",
        "CI": "true",  # Enables non-interactive mode (no TTY required)
    }

    # Stream output line by line so logs appear in real time instead of buffering until done
    proc = subprocess.Popen(
        cmd,
        cwd=repo_path,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # merge stderr into stdout stream
        text=True,
        env=claude_env,
    )

    stdout_lines = []
    for line in proc.stdout:
        line = line.rstrip("\n")
        print(line, flush=True)
        stdout_lines.append(line)

    proc.wait()
    full_output = "\n".join(stdout_lines)
    print(f"Claude Code exit code: {proc.returncode}", flush=True)

    if proc.returncode != 0:
        raise RuntimeError(
            f"Claude Code CLI exited with code {proc.returncode}\n"
            f"output: {full_output[-2000:]}"
        )

    # --- Collect changed files ---
    git_status = subprocess.run(
        ["git", "status", "--porcelain"],
        cwd=repo_path,
        capture_output=True,
        text=True,
        check=True,
    )

    print(f"Git status:\n{git_status.stdout}", flush=True)

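    files_changed = []
    # Don't strip() the whole output: porcelain lines can start with a space
    # (e.g. " M train.py"), and stripping the first line would shift its path.
    for line in git_status.stdout.splitlines():
        if line:
            # git status --porcelain: "XY PATH", where X and Y are status flags
            file_path = line[3:].strip()
            files_changed.append(file_path)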

    # Also list all files in repo dir for debugging
    all_files = subprocess.run(
        ["find", ".", "-type", "f", "-not", "-path", "./.git/*"],
        cwd=repo_path,
        capture_output=True,
        text=True,
    )
    print(f"All files in repo:\n{all_files.stdout}", flush=True)

    if not files_changed:
        raise RuntimeError(
            "Claude Code ran successfully but produced no file changes.\n"
            f"output: {full_output[-2000:]}"
        )

    # --- Commit ---
    subprocess.run(["git", "add", "."], cwd=repo_path, check=True)
    subprocess.run(["git", "add", "-f", "results.tsv"], cwd=repo_path, check=False)
    subprocess.run(["git", "add", "-f", "results/"], cwd=repo_path, check=False)
    commit_message = (
        "feat: autoresearch run via Claude Code\n\n"
        "Added research results (CSV) and updated train/ code changes.\n"
        "Generated by the autoresearch Flyte agent."
    )
    subprocess.run(
        ["git", "commit", "-m", commit_message],
        cwd=repo_path,
        check=True,
    )

    # --- Push ---
    print(f"GitHub token present: {bool(github_token)}, length: {len(github_token) if github_token else 0}", flush=True)
    authenticated_url = AUTORESEARCH_REPO_URL.replace(
        "https://", f"https://{GITHUB_USERNAME}:{github_token}@"
    )
    subprocess.run(
        ["git", "remote", "set-url", "origin", authenticated_url],
        cwd=repo_path,
        check=True,
    )
    push_result = subprocess.run(
        ["git", "push", "-u", "origin", branch_name, "--force"],
        cwd=repo_path,
        capture_output=True,
        text=True,
    )
    print(f"Push stdout: {push_result.stdout}", flush=True)
    print(f"Push stderr: {push_result.stderr}", flush=True)
    if push_result.returncode != 0:
        raise RuntimeError(f"git push failed (exit {push_result.returncode}):\n{push_result.stderr}")

    # --- Create PR via PyGithub ---
    auth = Auth.Token(github_token)
    gh = Github(auth=auth)
    repo = gh.get_repo(AUTORESEARCH_REPO_FULL_NAME)

    csv_files = [f for f in files_changed if f.endswith(".csv")]
    train_files = [f for f in files_changed if "train" in f]

    pr_body = f"""## AutoResearch Run

This PR was automatically generated by the autoresearch Flyte agent using Claude Code CLI.

### What changed
- **Result CSV files**: {', '.join(f'`{f}`' for f in csv_files) or 'none detected'}
- **Train code changes**: {', '.join(f'`{f}`' for f in train_files) or 'none detected'}

### All changed files
{chr(10).join(f'- `{f}`' for f in files_changed)}

---
🤖 Generated by [autoresearch Flyte agent](https://github.com/unionai-oss/autoresearch)
"""

    existing_prs = list(repo.get_pulls(state="open", head=f"unionai-oss:{branch_name}"))
    if existing_prs:
        pr = existing_prs[0]
        print(f"PR already exists: {pr.html_url}", flush=True)
    else:
        pr = repo.create_pull(
            title="feat: autoresearch results + train changes",
            body=pr_body,
            head=branch_name,
            base="master",
        )
        print(f"PR created: {pr.html_url}", flush=True)

    # --- Generate progress plot from results.tsv ---
    plot_path = repo_path / "progress.png"
    results_tsv = repo_path / "results.tsv"
    if results_tsv.exists():
        import matplotlib
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
        import pandas as pd

        df = pd.read_csv(str(results_tsv), sep="\t")
        df["val_bpb"] = pd.to_numeric(df["val_bpb"], errors="coerce")
        df["memory_gb"] = pd.to_numeric(df["memory_gb"], errors="coerce")
        df["status"] = df["status"].str.strip().str.upper()

        # Filter out crashes for plotting
        valid = df[df["status"] != "CRASH"].copy()
        valid = valid.reset_index(drop=True)

        if len(valid) > 0 and valid["val_bpb"].notna().any():
            # Baseline = first non-crash experiment with a numeric val_bpb (avoids NaN axis limits)
            baseline_bpb = valid["val_bpb"].dropna().iloc[0]
            best = valid["val_bpb"].min()

            # Only plot points at or below baseline (the interesting region)
            below = valid[valid["val_bpb"] <= baseline_bpb + 0.0005]

            fig, ax = plt.subplots(figsize=(16, 8))

            # Plot discarded as faint background dots
            disc = below[below["status"] == "DISCARD"]
            ax.scatter(disc.index, disc["val_bpb"],
                       c="#cccccc", s=12, alpha=0.5, zorder=2, label="Discarded")

            # Plot kept experiments as prominent green dots
            kept_v = below[below["status"] == "KEEP"]
            ax.scatter(kept_v.index, kept_v["val_bpb"],
                       c="#2ecc71", s=50, zorder=4, label="Kept", edgecolors="black", linewidths=0.5)

            # Running minimum step line
            kept_mask = valid["status"] == "KEEP"
            kept_idx = valid.index[kept_mask]
            kept_bpb = valid.loc[kept_mask, "val_bpb"]
            running_min = kept_bpb.cummin()
            ax.step(kept_idx, running_min, where="post", color="#27ae60",
                    linewidth=2, alpha=0.7, zorder=3, label="Running best")

            # Label each kept experiment with its description
            for idx, bpb in zip(kept_idx, kept_bpb):
                desc = str(valid.loc[idx, "description"]).strip()
                if len(desc) > 45:
                    desc = desc[:42] + "..."
                ax.annotate(desc, (idx, bpb),
                            textcoords="offset points",
                            xytext=(6, 6), fontsize=8.0,
                            color="#1a7a3a", alpha=0.9,
                            rotation=30, ha="left", va="bottom")

            n_total = len(df)
            n_kept = len(df[df["status"] == "KEEP"])
            ax.set_xlabel("Experiment #", fontsize=12)
            ax.set_ylabel("Validation BPB (lower is better)", fontsize=12)
            ax.set_title(f"Autoresearch Progress: {n_total} Experiments, {n_kept} Kept Improvements", fontsize=14)
            ax.legend(loc="upper right", fontsize=9)
            ax.grid(True, alpha=0.2)

            margin = (baseline_bpb - best) * 0.15
            ax.set_ylim(best - margin, baseline_bpb + margin)

            plt.tight_layout()
            plt.savefig(str(plot_path), dpi=150, bbox_inches="tight")
            plt.close(fig)
            print(f"Saved plot to {plot_path}", flush=True)

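            # Force-add the plot, amend the commit, and push so the image is on the branch
            subprocess.run(["git", "add", "-f", str(plot_path)], cwd=repo_path, check=False)
            subprocess.run(
                ["git", "commit", "--amend", "--no-edit"],
                cwd=repo_path, check=False,
            )
            subprocess.run(
                ["git", "push", "-u", "origin", branch_name, "--force"],
                cwd=repo_path, check=False,
            )

            # Link the committed plot from a PR comment. GitHub does not render data: URIs
            # in Markdown images, and a base64-encoded PNG of this size can exceed the
            # 65,536-character comment limit.
            head_sha = subprocess.run(
                ["git", "rev-parse", "HEAD"],
                cwd=repo_path, capture_output=True, text=True, check=True,
            ).stdout.strip()
            plot_url = f"https://github.com/{AUTORESEARCH_REPO_FULL_NAME}/raw/{head_sha}/progress.png"
            pr.create_issue_comment(
                f"## Autoresearch Progress\n\n![Autoresearch Progress]({plot_url})"
            )
            print("Posted plot as PR comment.", flush=True)

            # Base64-encode the plot for the inline image in the Flyte report
            import base64
            img_b64 = base64.b64encode(plot_path.read_bytes()).decode()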

            # Show plot in Flyte UI via report
            await flyte.report.replace.aio(
                f"<h2>Autoresearch Progress</h2>"
                f'<img src="data:image/png;base64,{img_b64}" style="max-width:100%"/>'
                f'<p><a href="{pr.html_url}">View PR</a></p>'
            )
            await flyte.report.flush.aio()
        else:
            print("results.tsv found but no valid val_bpb rows — skipping plot.", flush=True)
    else:
        print("results.tsv not found — skipping plot.", flush=True)

    return AutoResearchResult(
        pr_url=pr.html_url,
        pr_number=pr.number,
        branch_name=branch_name,
        files_changed=files_changed,
        success=True,
    )

# {{docs-fragment main}}
if __name__ == "__main__":
    import time

    flyte.init_from_config()

    run = flyte.with_runcontext(mode="remote").run(run_autoresearch)

    print(f"AutoResearch run started: {run.url}")
    print("Waiting for completion...")

    while True:
        try:
            run.wait()
            break
        except Exception as e:
            print(f"Connection dropped ({e}), reconnecting in 30s...")
            time.sleep(30)

    print(f"Done! See run at: {run.url}")
# {{/docs-fragment main}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/autoresearch/run.py*
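When the task collects changed files, it calls `git_status.stdout.strip()` before splitting lines. In `git status --porcelain` (v1) output, every line starts with a two-character `XY` status code followed by a space, and an unstaged modification is reported as ` M` with a leading space. So when the first entry is something like ` M train.py`, stripping the whole output turns it into `rain.py`, which also hides it from the `"train" in f` filter used for the PR body. Below is a minimal sketch of a more defensive parser, assuming porcelain v1 output. `parse_porcelain` is a hypothetical helper, not part of the tutorial's code, and it does not unquote the C-style quoting Git applies to unusual filenames.

```python
def parse_porcelain(output: str) -> list[str]:
    """Return the paths listed in `git status --porcelain` (v1) output.

    Hypothetical helper. Each line is "XY PATH", or "XY ORIG -> PATH" for
    renames and copies. The output is split without stripping it first,
    because a leading space is part of the status code (for example " M").
    """
    paths = []
    for line in output.splitlines():
        if len(line) < 4:
            continue  # skip blank or malformed lines
        path = line[3:]  # drop the two status characters and the separating space
        if " -> " in path:
            path = path.split(" -> ", 1)[1]  # keep the destination of a rename
        paths.append(path)
    return paths


sample = " M train.py\n?? results.tsv\nR  old.py -> new.py\n"
print(parse_porcelain(sample))  # ['train.py', 'results.tsv', 'new.py']
print(sample.strip().splitlines()[0][3:])  # rain.py  (the strip-first approach)
```

Swapping it in as `files_changed = parse_porcelain(git_status.stdout)` is one option; passing `-z` to `git status` is the more robust route if filenames may contain spaces or non-ASCII characters, since it disables quoting and NUL-terminates each entry.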

## The autoresearch task

The task is one long, sequential procedure. It starts by installing Node.js and the Claude Code CLI at runtime (which keeps the container image small), cloning the repo, configuring the git identity, creating a timestamped branch, and loading `program.md` as the prompt:

```
@autoresearch_env.task(report=True)
async def run_autoresearch() -> AutoResearchResult:
    """
    Run the autoresearch workflow end-to-end.

    Steps:
    - Clone https://github.com/unionai-oss/autoresearch
    - Configure git identity
    - Create a new branch
    - Run Claude Code CLI with program.md as the prompt
    - Commit results (CSV + train/ changes)
    - Push and open a PR against the autoresearch repo
    """
    github_token = os.environ["GITHUB_TOKEN"]
    anthropic_api_key = os.environ["ANTHROPIC_API_KEY"]

    # --- Install Node.js + Claude Code at runtime (keeps image small and submission fast) ---
    import tarfile
    import urllib.request as _urllib

    subprocess.run(["apt-get", "update", "-y"], check=False)
    subprocess.run(["apt-get", "install", "-y", "git"], check=False)

    node_url = "https://nodejs.org/dist/v20.19.0/node-v20.19.0-linux-x64.tar.gz"
    node_tar = Path("/tmp/node.tar.gz")
    print(f"Downloading Node.js from {node_url}...", flush=True)
    _urllib.urlretrieve(node_url, node_tar)
    size_mb = node_tar.stat().st_size / 1024 / 1024
    print(f"Downloaded {size_mb:.1f} MB to {node_tar}", flush=True)
    if size_mb < 1:
        raise RuntimeError(f"Node.js download appears empty/corrupt ({size_mb:.2f} MB) — network may be restricted")
    node_dir = Path("/tmp/node")
    node_dir.mkdir(exist_ok=True)
    print("Extracting Node.js...", flush=True)
    with tarfile.open(node_tar, "r:gz") as tar:
        members = [m for m in tar.getmembers() if m.name.split("/", 1)[-1]]
        for m in members:
            m.name = m.name.split("/", 1)[-1]
        tar.extractall(str(node_dir), members=[m for m in members if m.name])

    # Add node/npm to PATH for this process and all subprocesses
    node_bin = str(node_dir / "bin")
    os.environ["PATH"] = node_bin + ":" + os.environ.get("PATH", "")
    print(f"Node version: {subprocess.run(['node', '--version'], capture_output=True, text=True).stdout.strip()}", flush=True)

    npm_prefix = "/tmp/npm-global"
    Path(npm_prefix).mkdir(exist_ok=True)
    subprocess.run(["npm", "install", "-g", "--prefix", npm_prefix, "@anthropic-ai/claude-code"], check=True)
    os.environ["PATH"] = str(Path(npm_prefix) / "bin") + ":" + os.environ["PATH"]
    print("Node.js + Claude Code installed.", flush=True)

    # --- Clone repo ---
    work_dir = Path("/tmp/autoresearch_workspace")
    work_dir.mkdir(exist_ok=True, parents=True)
    repo_path = clone_repository(AUTORESEARCH_REPO_URL, work_dir, github_token)

    # --- Git identity ---
    subprocess.run(
        ["git", "config", "--global", "user.email", GITHUB_EMAIL], check=True
    )
    subprocess.run(
        ["git", "config", "--global", "user.name", GITHUB_USERNAME], check=True
    )

    # --- Create branch ---
    import time as _time
    branch_name = f"autoresearch/claude-run-{int(_time.time())}"
    try:
        subprocess.run(
            ["git", "checkout", "-b", branch_name],
            cwd=repo_path,
            check=True,
        )
    except subprocess.CalledProcessError:
        subprocess.run(
            ["git", "checkout", branch_name],
            cwd=repo_path,
            check=True,
        )

    # --- Read program.md to use as the Claude Code prompt ---
    program_md = repo_path / "program.md"
    if not program_md.exists():
        raise FileNotFoundError(
            f"program.md not found in {repo_path}. "
            "Make sure the autoresearch repo has a program.md at its root."
        )

    program_md_content = program_md.read_text()
    print(f"Loaded prompt from program.md ({len(program_md_content)} chars)")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/autoresearch/run.py*
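The Node.js install at the top of the task relies on the official tarball keeping everything under a single top-level directory (`node-v20.19.0-linux-x64/`): it renames each member to drop that first path component before extracting, the Python counterpart of `tar --strip-components=1`, so the binaries land in `/tmp/node/bin`. The sketch below isolates that step so it can be tried locally. `extract_stripped` is a hypothetical helper, the small archive built in a temporary directory stands in for the real download, and `filter="data"` is only passed on Python versions that provide tarfile extraction filters.

```python
import io
import tarfile
import tempfile
from pathlib import Path


def extract_stripped(tar_path: Path, dest: Path) -> list[str]:
    """Extract a .tar.gz into dest, dropping the first path component of each member.

    Hypothetical helper, similar in spirit to `tar -xzf ... --strip-components=1`.
    Members with nothing left after stripping (the top-level directory) are skipped.
    """
    dest.mkdir(parents=True, exist_ok=True)
    with tarfile.open(tar_path, "r:gz") as tar:
        members = []
        for member in tar.getmembers():
            _, _, rest = member.name.partition("/")
            if rest:
                member.name = rest  # e.g. "node-v20.19.0-linux-x64/bin/node" -> "bin/node"
                members.append(member)
        # Use the safer "data" extraction filter where available (Python 3.12+ and backports)
        kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
        tar.extractall(str(dest), members=members, **kwargs)
    return sorted(member.name for member in members)


with tempfile.TemporaryDirectory() as tmp_dir:
    tmp = Path(tmp_dir)
    archive = tmp / "demo.tar.gz"
    # Build a tiny archive shaped like the Node.js tarball: one top-level directory
    with tarfile.open(archive, "w:gz") as tar:
        top = tarfile.TarInfo("node-demo")
        top.type = tarfile.DIRTYPE
        tar.addfile(top)
        for name, data in {"node-demo/bin/node": b"#!fake", "node-demo/README.md": b"hi"}.items():
            info = tarfile.TarInfo(name)
            info.size = len(data)
            tar.addfile(info, io.BytesIO(data))

    names = extract_stripped(archive, tmp / "node")
    content = (tmp / "node" / "bin" / "node").read_bytes()

print(names)  # ['README.md', 'bin/node']
```

The task's own version filters on `m.name.split("/", 1)[-1]`; `str.partition` is used here so that a name with no slash at all yields an empty remainder and is skipped rather than extracted under its original name.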

From there the task:

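1. Installs the repository's Python dependencies (`pip install -e .`, plus `requirements.txt` if present), then wraps the `program.md` brief with explicit logging and "write outputs to disk" instructions.
2. Checks that the `claude` CLI runs, disables the Claude Code sandbox (inside a Kubernetes pod it tries to start a nested container, and file writes can silently land in ephemeral storage instead of the working directory), and runs the CLI non-interactively with a 100-turn cap, streaming its output to the Flyte logs in real time.
3. Collects the files the agent changed via `git status --porcelain`, fails the run if there are none, commits them, and force-pushes the branch.
4. Opens (or reuses) a pull request against `master` with [PyGithub](https://pygithub.readthedocs.io/).
5. If the agent produced a `results.tsv`, renders a progress plot of validation bits-per-byte (BPB), posts it as a PR comment, amends the commit to include `progress.png`, and shows it in the task's Flyte report: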

```
# /// script
# requires-python = ">=3.11"
# dependencies = [
#     "flyte>=2.0.0b22",
#     "PyGithub>=2.5.0",
#     "matplotlib>=3.7.0",
# ]
# ///

"""
AutoResearch Agent - Runs the autoresearch workflow using Claude Code CLI in a GPU environment.

This agent:
1. Starts a GPU-enabled container
2. Installs Claude Code CLI
3. Clones the autoresearch repository
4. Points Claude Code at program.md as the prompt and lets it run
5. Commits the result (CSV + code changes in train/) and creates a PR
"""

import os
import shlex
import subprocess
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

from github import Auth, Github

import flyte
import flyte.report
from _image import image as autoresearch_image

GITHUB_USERNAME = "parnianz"
GITHUB_EMAIL = "parnianzargham@gmail.com"
AUTORESEARCH_REPO_URL = "https://github.com/unionai-oss/autoresearch.git"
AUTORESEARCH_REPO_FULL_NAME = "unionai-oss/autoresearch"

# {{docs-fragment env}}
autoresearch_env = flyte.TaskEnvironment(
    name="autoresearch-agent",
    resources=flyte.Resources(
        cpu=8,
        memory="32Gi",
        gpu="T4:1",
        disk="100Gi",
    ),
    secrets=[
        flyte.Secret(key="github_token", as_env_var="GITHUB_TOKEN"),
        flyte.Secret(key="internal-anthropic-api-key", as_env_var="ANTHROPIC_API_KEY"),
    ],
    image=autoresearch_image,
)
# {{/docs-fragment env}}

# {{docs-fragment result}}
@dataclass
class AutoResearchResult:
    """Result of the autoresearch run."""

    pr_url: str
    pr_number: int
    branch_name: str
    files_changed: list[str]
    success: bool
    error_message: Optional[str] = None
# {{/docs-fragment result}}

def clone_repository(repo_url: str, work_dir: Path, github_token: str) -> Path:
    """Clone the autoresearch repository with authentication."""
    repo_name = repo_url.rstrip("/").split("/")[-1].replace(".git", "")
    repo_path = work_dir / repo_name

    # Inject token into HTTPS URL for authentication
    authenticated_url = repo_url.replace(
        "https://", f"https://{GITHUB_USERNAME}:{github_token}@"
    )

    if repo_path.exists():
        subprocess.run(["git", "pull"], cwd=repo_path, check=True)
    else:
        subprocess.run(["git", "clone", authenticated_url, str(repo_path)], check=True)

    return repo_path

# {{docs-fragment task}}
@autoresearch_env.task(report=True)
async def run_autoresearch() -> AutoResearchResult:
    """
    Run the autoresearch workflow end-to-end.

    Steps:
    - Clone https://github.com/unionai-oss/autoresearch
    - Configure git identity
    - Create a new branch
    - Run Claude Code CLI with program.md as the prompt
    - Commit results (CSV + train/ changes)
    - Push and open a PR against the autoresearch repo
    """
    github_token = os.environ["GITHUB_TOKEN"]
    anthropic_api_key = os.environ["ANTHROPIC_API_KEY"]

    # --- Install Node.js + Claude Code at runtime (keeps image small and submission fast) ---
    import tarfile
    import urllib.request as _urllib

    subprocess.run(["apt-get", "update", "-y"], check=False)
    subprocess.run(["apt-get", "install", "-y", "git"], check=False)

    node_url = "https://nodejs.org/dist/v20.19.0/node-v20.19.0-linux-x64.tar.gz"
    node_tar = Path("/tmp/node.tar.gz")
    print(f"Downloading Node.js from {node_url}...", flush=True)
    _urllib.urlretrieve(node_url, node_tar)
    size_mb = node_tar.stat().st_size / 1024 / 1024
    print(f"Downloaded {size_mb:.1f} MB to {node_tar}", flush=True)
    if size_mb < 1:
        raise RuntimeError(f"Node.js download appears empty/corrupt ({size_mb:.2f} MB) — network may be restricted")
    node_dir = Path("/tmp/node")
    node_dir.mkdir(exist_ok=True)
    print("Extracting Node.js...", flush=True)
    with tarfile.open(node_tar, "r:gz") as tar:
        members = [m for m in tar.getmembers() if m.name.split("/", 1)[-1]]
        for m in members:
            m.name = m.name.split("/", 1)[-1]
        tar.extractall(str(node_dir), members=[m for m in members if m.name])

    # Add node/npm to PATH for this process and all subprocesses
    node_bin = str(node_dir / "bin")
    os.environ["PATH"] = node_bin + ":" + os.environ.get("PATH", "")
    print(f"Node version: {subprocess.run(['node', '--version'], capture_output=True, text=True).stdout.strip()}", flush=True)

    npm_prefix = "/tmp/npm-global"
    Path(npm_prefix).mkdir(exist_ok=True)
    subprocess.run(["npm", "install", "-g", "--prefix", npm_prefix, "@anthropic-ai/claude-code"], check=True)
    os.environ["PATH"] = str(Path(npm_prefix) / "bin") + ":" + os.environ["PATH"]
    print("Node.js + Claude Code installed.", flush=True)

    # --- Clone repo ---
    work_dir = Path("/tmp/autoresearch_workspace")
    work_dir.mkdir(exist_ok=True, parents=True)
    repo_path = clone_repository(AUTORESEARCH_REPO_URL, work_dir, github_token)

    # --- Git identity ---
    subprocess.run(
        ["git", "config", "--global", "user.email", GITHUB_EMAIL], check=True
    )
    subprocess.run(
        ["git", "config", "--global", "user.name", GITHUB_USERNAME], check=True
    )

    # --- Create branch ---
    import time as _time
    branch_name = f"autoresearch/claude-run-{int(_time.time())}"
    try:
        subprocess.run(
            ["git", "checkout", "-b", branch_name],
            cwd=repo_path,
            check=True,
        )
    except subprocess.CalledProcessError:
        subprocess.run(
            ["git", "checkout", branch_name],
            cwd=repo_path,
            check=True,
        )

    # --- Read program.md to use as the Claude Code prompt ---
    program_md = repo_path / "program.md"
    if not program_md.exists():
        raise FileNotFoundError(
            f"program.md not found in {repo_path}. "
            "Make sure the autoresearch repo has a program.md at its root."
        )

    program_md_content = program_md.read_text()
    print(f"Loaded prompt from program.md ({len(program_md_content)} chars)")
    # {{/docs-fragment task}}

    # Install repo dependencies before handing off to Claude
    for pip_cmd in [
        ["pip", "install", "-e", "."],
        ["pip", "install", "-r", "requirements.txt"],
    ]:
        req_file = repo_path / pip_cmd[-1] if pip_cmd[-1].startswith("req") else None
        if req_file is None or req_file.exists():
            dep_result = subprocess.run(
                pip_cmd, cwd=repo_path, capture_output=True, text=True
            )
            print(f"{' '.join(pip_cmd)}:\n{dep_result.stdout}", flush=True)
            if dep_result.returncode != 0:
                print(f"(non-fatal) {dep_result.stderr}", flush=True)

    # Wrap the program.md content with explicit instructions to write outputs to disk
    prompt = f"""You are running inside an automated GPU pipeline. You MUST write all outputs to disk as actual files.

Here are your instructions from program.md:

{program_md_content}

LOGGING INSTRUCTIONS (follow exactly):
- Before you start any training, print this exact line: [AUTORESEARCH] Training started
- Before training, print what change you are testing: [AUTORESEARCH] Change: <one line description of the code change being tested>
- When training finishes, print this exact line: [AUTORESEARCH] Training finished
- After training, print the key metric value: [AUTORESEARCH] Metric: <metric name>=<value>
- When writing results to CSV, print this exact line: [AUTORESEARCH] Writing results to CSV

IMPORTANT: After completing the above instructions, make sure you have:
1. Written the final results to a CSV file in this repository (e.g. results/results.csv or similar)
2. Saved all code changes you made to the train/ directory (or wherever the training code lives)
3. All files must be written to the current working directory so they appear in git status
If any command fails, debug and fix it rather than stopping. Do not just print results — write them to files on disk."""

    # --- Pre-flight: verify claude is installed and API key is reachable ---
    version_check = subprocess.run(
        ["claude", "--version"], capture_output=True, text=True
    )
    print(f"claude version: {version_check.stdout.strip()} | stderr: {version_check.stderr.strip()}", flush=True)
    if version_check.returncode != 0:
        raise RuntimeError(f"claude CLI not found or broken: {version_check.stderr}")

    # --- Disable Claude Code sandbox ---
    # In Kubernetes/Flyte pods, Claude Code's sandbox tries to spin up a nested container
    # which fails silently and causes file writes to go to an ephemeral space instead of
    # the real working directory. Disabling it makes writes land in the actual filesystem.
    claude_config_dir = Path("/root/.claude")
    claude_config_dir.mkdir(parents=True, exist_ok=True)
    settings = claude_config_dir / "settings.json"
    import json as _json
    existing = _json.loads(settings.read_text()) if settings.exists() else {}
    existing["sandbox"] = False
    settings.write_text(_json.dumps(existing, indent=2))
    print(f"Wrote Claude Code settings: {settings.read_text()}", flush=True)

    # --- Run Claude Code CLI ---
    # Matches swe_agent.py exactly: prompt as positional arg, CI=true enables non-interactive mode
    cmd = [
        "claude",
        "--dangerously-skip-permissions",
        "--max-turns", "100",
        "--model", "claude-haiku-4-5-20251001",
        prompt,
    ]

    print(f"Running: {shlex.join(cmd[:3])} <prompt>", flush=True)

    claude_env = {
        **os.environ,
        "ANTHROPIC_API_KEY": anthropic_api_key,
        "CLAUDE_SKIP_PERMISSIONS": "true",
        "CI": "true",  # Enables non-interactive mode (no TTY required)
    }

    # Stream output line by line so logs appear in real time instead of buffering until done
    proc = subprocess.Popen(
        cmd,
        cwd=repo_path,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # merge stderr into stdout stream
        text=True,
        env=claude_env,
    )

    stdout_lines = []
    for line in proc.stdout:
        line = line.rstrip("\n")
        print(line, flush=True)
        stdout_lines.append(line)

    proc.wait()
    full_output = "\n".join(stdout_lines)
    print(f"Claude Code exit code: {proc.returncode}", flush=True)

    if proc.returncode != 0:
        raise RuntimeError(
            f"Claude Code CLI exited with code {proc.returncode}\n"
            f"output: {full_output[-2000:]}"
        )

    # --- Collect changed files ---
    git_status = subprocess.run(
        ["git", "status", "--porcelain"],
        cwd=repo_path,
        capture_output=True,
        text=True,
        check=True,
    )

    print(f"Git status:\n{git_status.stdout}", flush=True)

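    files_changed = []
    # Don't strip() the whole output: an unstaged change starts with a space (" M path"),
    # and stripping it would shift the fixed path offset on the first line
    for line in git_status.stdout.splitlines():
        if line:
            # git status --porcelain: first two chars are XY status flags, then a space
            file_path = line[3:].strip()
            files_changed.append(file_path)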

    # Also list all files in repo dir for debugging
    all_files = subprocess.run(
        ["find", ".", "-type", "f", "-not", "-path", "./.git/*"],
        cwd=repo_path,
        capture_output=True,
        text=True,
    )
    print(f"All files in repo:\n{all_files.stdout}", flush=True)

    if not files_changed:
        raise RuntimeError(
            "Claude Code ran successfully but produced no file changes.\n"
            f"output: {full_output[-2000:]}"
        )

    # --- Commit ---
    subprocess.run(["git", "add", "."], cwd=repo_path, check=True)
    subprocess.run(["git", "add", "-f", "results.tsv"], cwd=repo_path, check=False)
    subprocess.run(["git", "add", "-f", "results/"], cwd=repo_path, check=False)
    commit_message = (
        "feat: autoresearch run via Claude Code\n\n"
        "Added research results (CSV) and updated train/ code changes.\n"
        "Generated by the autoresearch Flyte agent."
    )
    subprocess.run(
        ["git", "commit", "-m", commit_message],
        cwd=repo_path,
        check=True,
    )

    # --- Push ---
    print(f"GitHub token present: {bool(github_token)}, length: {len(github_token) if github_token else 0}", flush=True)
    authenticated_url = AUTORESEARCH_REPO_URL.replace(
        "https://", f"https://{GITHUB_USERNAME}:{github_token}@"
    )
    subprocess.run(
        ["git", "remote", "set-url", "origin", authenticated_url],
        cwd=repo_path,
        check=True,
    )
    push_result = subprocess.run(
        ["git", "push", "-u", "origin", branch_name, "--force"],
        cwd=repo_path,
        capture_output=True,
        text=True,
    )
    print(f"Push stdout: {push_result.stdout}", flush=True)
    print(f"Push stderr: {push_result.stderr}", flush=True)
    if push_result.returncode != 0:
        raise RuntimeError(f"git push failed (exit {push_result.returncode}):\n{push_result.stderr}")

    # --- Create PR via PyGithub ---
    auth = Auth.Token(github_token)
    gh = Github(auth=auth)
    repo = gh.get_repo(AUTORESEARCH_REPO_FULL_NAME)

    csv_files = [f for f in files_changed if f.endswith(".csv")]
    train_files = [f for f in files_changed if "train" in f]

    pr_body = f"""## AutoResearch Run

This PR was automatically generated by the autoresearch Flyte agent using Claude Code CLI.

### What changed
- **Result CSV files**: {', '.join(f'`{f}`' for f in csv_files) or 'none detected'}
- **Train code changes**: {', '.join(f'`{f}`' for f in train_files) or 'none detected'}

### All changed files
{chr(10).join(f'- `{f}`' for f in files_changed)}

---
🤖 Generated by [autoresearch Flyte agent](https://github.com/unionai-oss/autoresearch)
"""

    # Filter by the target repo's owner rather than a hard-coded org so forks work too
    repo_owner = AUTORESEARCH_REPO_FULL_NAME.split("/")[0]
    existing_prs = list(repo.get_pulls(state="open", head=f"{repo_owner}:{branch_name}"))
    if existing_prs:
        pr = existing_prs[0]
        print(f"PR already exists: {pr.html_url}", flush=True)
    else:
        pr = repo.create_pull(
            title="feat: autoresearch results + train changes",
            body=pr_body,
            head=branch_name,
            base=repo.default_branch,
        )
        print(f"PR created: {pr.html_url}", flush=True)

    # --- Generate progress plot from results.tsv ---
    plot_path = repo_path / "progress.png"
    results_tsv = repo_path / "results.tsv"
    if results_tsv.exists():
        import matplotlib
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
        import pandas as pd

        df = pd.read_csv(str(results_tsv), sep="\t")
        df["val_bpb"] = pd.to_numeric(df["val_bpb"], errors="coerce")
        df["memory_gb"] = pd.to_numeric(df["memory_gb"], errors="coerce")
        df["status"] = df["status"].str.strip().str.upper()

        # Filter out crashes for plotting
        valid = df[df["status"] != "CRASH"].copy()
        valid = valid.reset_index(drop=True)

        if len(valid) > 0 and valid["val_bpb"].notna().any():
            baseline_bpb = valid.loc[0, "val_bpb"]
            best = valid["val_bpb"].min()

            # Only plot points at or below baseline (the interesting region)
            below = valid[valid["val_bpb"] <= baseline_bpb + 0.0005]

            fig, ax = plt.subplots(figsize=(16, 8))

            # Plot discarded as faint background dots
            disc = below[below["status"] == "DISCARD"]
            ax.scatter(disc.index, disc["val_bpb"],
                       c="#cccccc", s=12, alpha=0.5, zorder=2, label="Discarded")

            # Plot kept experiments as prominent green dots
            kept_v = below[below["status"] == "KEEP"]
            ax.scatter(kept_v.index, kept_v["val_bpb"],
                       c="#2ecc71", s=50, zorder=4, label="Kept", edgecolors="black", linewidths=0.5)

            # Running minimum step line
            kept_mask = valid["status"] == "KEEP"
            kept_idx = valid.index[kept_mask]
            kept_bpb = valid.loc[kept_mask, "val_bpb"]
            running_min = kept_bpb.cummin()
            ax.step(kept_idx, running_min, where="post", color="#27ae60",
                    linewidth=2, alpha=0.7, zorder=3, label="Running best")

            # Label each kept experiment with its description
            for idx, bpb in zip(kept_idx, kept_bpb):
                desc = str(valid.loc[idx, "description"]).strip()
                if len(desc) > 45:
                    desc = desc[:42] + "..."
                ax.annotate(desc, (idx, bpb),
                            textcoords="offset points",
                            xytext=(6, 6), fontsize=8.0,
                            color="#1a7a3a", alpha=0.9,
                            rotation=30, ha="left", va="bottom")

            n_total = len(df)
            n_kept = len(df[df["status"] == "KEEP"])
            ax.set_xlabel("Experiment #", fontsize=12)
            ax.set_ylabel("Validation BPB (lower is better)", fontsize=12)
            ax.set_title(f"Autoresearch Progress: {n_total} Experiments, {n_kept} Kept Improvements", fontsize=14)
            ax.legend(loc="upper right", fontsize=9)
            ax.grid(True, alpha=0.2)

            # Pad the y-range; use a small minimum pad when nothing beat the baseline
            margin = max((baseline_bpb - best) * 0.15, 1e-3)
            ax.set_ylim(best - margin, baseline_bpb + margin)

            plt.tight_layout()
            plt.savefig(str(plot_path), dpi=150, bbox_inches="tight")
            plt.close(fig)
            print(f"Saved plot to {plot_path}", flush=True)

            # Force-add plot to git, amend the commit, and push so the image exists on the branch
            subprocess.run(["git", "add", "-f", str(plot_path)], cwd=repo_path, check=False)
            subprocess.run(
                ["git", "commit", "--amend", "--no-edit"],
                cwd=repo_path, check=False,
            )
            subprocess.run(
                ["git", "push", "-u", "origin", branch_name, "--force"],
                cwd=repo_path, check=False,
            )

            # Link the pushed image from a PR comment. GitHub does not render data: URIs in
            # Markdown images, and a base64 PNG this size can exceed the 65,536-character comment limit.
            plot_url = f"https://github.com/{AUTORESEARCH_REPO_FULL_NAME}/blob/{branch_name}/progress.png?raw=true"
            pr.create_issue_comment(
                "## Autoresearch Progress\n\n"
                f"![Autoresearch Progress]({plot_url})"
            )
            print("Posted plot as PR comment.", flush=True)

            # The Flyte report is plain HTML, so an inline base64 image works there
            import base64
            img_b64 = base64.b64encode(plot_path.read_bytes()).decode()

            # Show plot in Flyte UI via report
            await flyte.report.replace.aio(
                f"<h2>Autoresearch Progress</h2>"
                f'<img src="data:image/png;base64,{img_b64}" style="max-width:100%"/>'
                f'<p><a href="{pr.html_url}">View PR</a></p>'
            )
            await flyte.report.flush.aio()
        else:
            print("results.tsv found but no valid val_bpb rows — skipping plot.", flush=True)
    else:
        print("results.tsv not found — skipping plot.", flush=True)

    return AutoResearchResult(
        pr_url=pr.html_url,
        pr_number=pr.number,
        branch_name=branch_name,
        files_changed=files_changed,
        success=True,
    )

# {{docs-fragment main}}
if __name__ == "__main__":
    import time

    flyte.init_from_config()

    run = flyte.with_runcontext(mode="remote").run(run_autoresearch)

    print(f"AutoResearch run started: {run.url}")
    print("Waiting for completion...")

    while True:
        try:
            run.wait()
            break
        except Exception as e:
            print(f"Connection dropped ({e}), reconnecting in 30s...")
            time.sleep(30)

    print(f"Done! See run at: {run.url}")
# {{/docs-fragment main}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/autoresearch/run.py*

The entry point submits the task in `remote` mode and reconnects automatically if the client connection drops during the long run.
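The reconnect loop in `__main__` retries forever and treats any exception from `run.wait()` as a dropped connection. If you want an upper bound, a sketch like the following caps the attempts. `wait_with_reconnect` is a hypothetical helper, not part of the example; its default `retry_on=(Exception,)` mirrors the original loop, so narrow it to whatever transport error your client actually raises.

```python
import time
from typing import Callable, Tuple, Type


def wait_with_reconnect(
    wait: Callable[[], object],
    max_attempts: int = 20,
    delay_sec: float = 30.0,
    retry_on: Tuple[Type[BaseException], ...] = (Exception,),
    sleep: Callable[[float], None] = time.sleep,
) -> int:
    """Call `wait()` until it returns, sleeping between failed attempts.

    Returns the 1-based attempt that succeeded; re-raises the last error
    after `max_attempts` failures instead of looping forever.
    """
    if max_attempts < 1:
        raise ValueError("max_attempts must be >= 1")
    for attempt in range(1, max_attempts + 1):
        try:
            wait()
            return attempt
        except retry_on as exc:
            if attempt == max_attempts:
                raise  # give up: surface the last error to the caller
            print(f"Connection dropped ({exc}), reconnecting in {delay_sec:g}s...")
            sleep(delay_sec)
    raise AssertionError("unreachable")
```

In the entry point this would replace the `while True` loop with a single call such as `wait_with_reconnect(run.wait)`. Injecting `sleep` keeps the helper testable without real delays.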

## Run the agent

### Create secrets

Get an Anthropic API key from the [Anthropic console](https://console.anthropic.com/) and a [GitHub personal access token](https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens) with permission to push and open PRs on the target repository.

Register both as Flyte secrets. The key names must match those declared in the `TaskEnvironment`:

```
flyte create secret github_token <YOUR_GITHUB_TOKEN>
flyte create secret internal-anthropic-api-key <YOUR_ANTHROPIC_API_KEY>
```

See [Secrets](https://www.union.ai/docs/v2/union/user-guide/task-configuration/secrets/page.md) for scoping and file-based secrets.

### Prepare the research repository

The target repository must contain a `program.md` at its root describing the research task for the agent. Set `AUTORESEARCH_REPO_URL` and `AUTORESEARCH_REPO_FULL_NAME` to a repository you control, and update the git identity constants (`GITHUB_EMAIL` and `GITHUB_USERNAME`) used for commits and pushes.
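The progress plot is built only if the run leaves a tab-separated `results.tsv` at the repository root. The plotting code reads its `val_bpb`, `memory_gb`, `status`, and `description` columns, upper-cases `status` (`KEEP`, `DISCARD`, or `CRASH`), drops `CRASH` rows, and treats the first remaining row as the baseline. If you spell this format out in `program.md`, a minimal writer sketch may help; it assumes those four columns are all you need, and `append_result` is a hypothetical helper rather than part of the example.

```python
from __future__ import annotations

import csv
from pathlib import Path

# Columns the plotting step reads from results.tsv.
COLUMNS = ["val_bpb", "memory_gb", "status", "description"]
STATUSES = {"KEEP", "DISCARD", "CRASH"}


def append_result(
    path: Path,
    val_bpb: float | None,
    memory_gb: float | None,
    status: str,
    description: str,
) -> None:
    """Append one experiment row to a tab-separated results file, writing the header first if needed."""
    status = status.strip().upper()
    if status not in STATUSES:
        raise ValueError(f"unknown status {status!r}; expected one of {sorted(STATUSES)}")
    write_header = not path.exists()
    with path.open("a", newline="") as f:
        writer = csv.writer(f, delimiter="\t")
        if write_header:
            writer.writerow(COLUMNS)
        writer.writerow([
            # Leave crashed runs' numbers empty; pandas reads empty cells back as NaN.
            "" if val_bpb is None else f"{val_bpb:.6f}",
            "" if memory_gb is None else f"{memory_gb:.1f}",
            status,
            " ".join(description.split()),  # collapse tabs/newlines so each row stays one line
        ])
```

Record the unmodified baseline run first, since the plot's y-range and "at or below baseline" filter are anchored on that row.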

### Run remotely

From the [example directory](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/autoresearch):

```
cd v2/tutorials/autoresearch
python run.py
```

This task runs remotely (it needs a GPU and network access). Follow the printed run URL to watch the agent's logs stream in, and open the run's report panel to see the progress plot once results are available. When the task finishes, the returned `AutoResearchResult` contains the pull request URL and number, the branch name, and the list of changed files.
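The prompt asks Claude Code to print `[AUTORESEARCH]` marker lines (`Change: ...`, `Metric: <name>=<value>`, and so on), which makes the streamed logs easy to scan. Below is a minimal sketch for pulling the tested changes and their metrics out of a saved copy of those logs. `parse_markers` and `Experiment` are hypothetical names, and because the markers depend on the model following the prompt, lines that don't match are simply skipped.

```python
from __future__ import annotations

import re
from dataclasses import dataclass, field
from typing import Iterable

# Matches "[AUTORESEARCH] Change: ..." and "[AUTORESEARCH] Metric: name=value" lines.
MARKER_RE = re.compile(r"\[AUTORESEARCH\]\s+(?P<kind>Change|Metric):\s*(?P<payload>.*)$")


@dataclass
class Experiment:
    change: str
    metrics: dict[str, float] = field(default_factory=dict)


def parse_markers(log_lines: Iterable[str]) -> list[Experiment]:
    """Group each Metric line under the most recent Change line."""
    experiments: list[Experiment] = []
    for line in log_lines:
        match = MARKER_RE.search(line)
        if not match:
            continue  # ordinary log output, or a marker the model didn't emit exactly
        payload = match.group("payload").strip()
        if match.group("kind") == "Change":
            experiments.append(Experiment(change=payload))
        elif experiments and "=" in payload:
            name, _, value = payload.partition("=")
            try:
                experiments[-1].metrics[name.strip()] = float(value)
            except ValueError:
                pass  # skip non-numeric values such as "n/a"
    return experiments
```

Metric lines that appear before any `Change` line are ignored, since there is no experiment to attach them to.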

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/agents/parallelized-autoresearch-agent ===

# Parallelized autoresearch agent

> **📝 Note**
>
> Code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/parallelized_autoresearch).

This tutorial extends the [Autoresearch agent](../autoresearch/_index) pattern with a code-mode MLE agent that plans **batches** of training experiments, saves distinct `train.py` edits, and runs them **in parallel** via `flyte.map`. It follows the [karpathy/autoresearch](https://github.com/karpathy/autoresearch) loop — minimize validation bits-per-byte on a TinyGPT variant — but orchestrates fan-out batches with durable Flyte tasks and [unionai-sandbox](https://www.union.ai/docs/v2/union/user-guide/sandboxing/_index) execution.

Compared to the single-threaded Claude Code autoresearch tutorial, this agent:

- Edits the full `train.py` source (upstream karpathy style) instead of delegating edits to the Claude Code CLI
- Uses **`code_mode=True`** so the LLM writes Python plans that call batch tools such as `run_experiment_batch`
- Persists a **leaderboard**, code-edit history, and batch plans in `MemoryStore`
- **Right-sizes each experiment** with an LLM via a `@tool` **`call_handler`**, then retries on Flyte or sandbox OOM by bumping memory
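The agent persists a leaderboard in `MemoryStore`, but its exact schema isn't shown in this excerpt. The sketch below assumes each entry is a plain dict with a `title` (the key the `_find_leaderboard_entry` helper in `tools.py` matches on) and a `val_bpb` score. `merge_batch` is a hypothetical helper showing one way to fold a parallel batch of results into a ranking where lower bits-per-byte is better.

```python
from __future__ import annotations

import math
from typing import Any


def merge_batch(
    leaderboard: list[dict[str, Any]],
    batch: list[dict[str, Any]],
    top_k: int = 10,
) -> list[dict[str, Any]]:
    """Merge one parallel batch into a leaderboard ranked by val_bpb (lower is better).

    Titles are matched case-insensitively; a rerun experiment keeps its best
    score. Entries without a finite val_bpb (crashes, OOMs) are dropped.
    """
    best: dict[str, dict[str, Any]] = {}
    for entry in [*leaderboard, *batch]:
        score = entry.get("val_bpb")
        if not isinstance(score, (int, float)) or not math.isfinite(score):
            continue
        key = str(entry.get("title", "")).strip().lower()
        if key not in best or score < best[key]["val_bpb"]:
            best[key] = entry
    return sorted(best.values(), key=lambda e: e["val_bpb"])[:top_k]
```

Matching titles case-insensitively mirrors the exact-match pass in `_find_leaderboard_entry`, and dropping non-finite scores keeps crashed or OOM-killed runs out of the ranking.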

Each experiment has different compute needs (wider models, larger batch sizes, longer training loops). A single static `flyte.Resources` on the task would either waste cluster memory or OOM on the heavy configs. Instead, this example uses the same [`call_handler` pattern](https://www.union.ai/docs/v2/union/user-guide/build-agent/flyte-agents/page.md) as the Flyte SDK self-correcting agent: before every run, a sizing LLM reads the tool name, docstring, and call arguments and returns a JSON resource spec; the handler applies it with `tool_fn.target.override(resources=...).aio(**kwargs)` and retries with more memory when needed.
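To make the control flow concrete without a cluster, here is a self-contained sketch of the same estimate, run, and bump-on-OOM loop. `Res`, `OOMError`, and `run_right_sized` are hypothetical stand-ins for `flyte.Resources`, `flyte.errors.OOMError`, and `execute_with_right_sizing`; the memory growth rule (at least double, or add 2 GiB, capped at 32 GiB) mirrors `_ensure_oom_increase` in `tools.py`.

```python
from __future__ import annotations

import asyncio
from dataclasses import dataclass, replace
from typing import Any, Awaitable, Callable

CEILING_MIB = 32 * 1024  # mirrors RESOURCE_CEILING (32Gi) in tools.py


class OOMError(Exception):
    """Stand-in for flyte.errors.OOMError."""


@dataclass(frozen=True)
class Res:
    """Stand-in for flyte.Resources, with memory kept in MiB."""
    cpu: int
    memory_mib: int


def bump(res: Res) -> Res:
    # At least double memory (or add 2 GiB), never exceeding the ceiling.
    grown = max(res.memory_mib * 2, res.memory_mib + 2048)
    return replace(res, memory_mib=min(CEILING_MIB, grown))


async def run_right_sized(
    estimate: Callable[[dict[str, Any]], Awaitable[Res]],
    task: Callable[[Res, dict[str, Any]], Awaitable[Any]],
    max_oom_retries: int = 3,
    **kwargs: Any,
) -> tuple[Any, list[Res]]:
    """Estimate resources once, then rerun `task` with more memory after each OOM."""
    res = await estimate(kwargs)  # an LLM call in the tutorial; any coroutine here
    tried = [res]
    for attempt in range(max_oom_retries + 1):
        try:
            return await task(res, kwargs), tried
        except OOMError:
            if attempt == max_oom_retries:
                raise  # out of retries: surface the OOM, as the tutorial does
            res = bump(res)
            tried.append(res)
    raise AssertionError("unreachable")
```

With a 2 GiB estimate and a task that needs 8 GiB, the loop runs at 2, 4, and 8 GiB before succeeding. The real handler additionally re-applies the resource floor, keeps CPU from shrinking, and treats a sandbox result with `oom` set the same way as a Flyte OOM.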

## Define the task environments

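The example defines three environments that share a Debian-based image with PyTorch and the sandbox tooling: `bundle_env` prepares the dataset bundle, `experiment_env` runs the sandboxed experiments, and `agent_env` runs the agent driver. Because `agent_env` lists the other two in `depends_on`, deploying the agent also deploys them.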

```
"""Shared Flyte environments and climbmix dataset bundle tasks."""

from __future__ import annotations

import os
import tempfile
from dataclasses import dataclass
from pathlib import Path

import flyte
from flyte.io import Dir

from autoresearch_types import DatasetProfile
from autoresearch_types import DEFAULT_NUM_SHARDS

TRAIN_PIP_PACKAGES = ["torch", "numpy", "pyarrow", "requests", "tiktoken", "rustbpe"]

_TUTORIAL_DIR = Path(__file__).parent
_INCLUDE = [str(p) for p in sorted(_TUTORIAL_DIR.glob("*.py"))]

image = flyte.Image.from_debian_base(name="mle-autoresearch").with_pip_packages(
    "litellm",
    "httpx",
    "pydantic-monty",
    "unionai-sandbox[flyte]",
    *TRAIN_PIP_PACKAGES,
)

bundle_env = flyte.TaskEnvironment(
    name="autoresearch-bundle",
    resources=flyte.Resources(cpu=4, memory="8Gi"),
    image=image,
    include=_INCLUDE,
)

experiment_env = flyte.TaskEnvironment(
    name="autoresearch-experiment",
    resources=flyte.Resources(cpu=2, memory="2Gi"),
    image=image,
    include=_INCLUDE,
    secrets=[flyte.Secret(key="internal-anthropic-api-key", as_env_var="ANTHROPIC_API_KEY")],
)

# {{docs-fragment env}}
agent_env = flyte.TaskEnvironment(
    name="autoresearch-agent",
    resources=flyte.Resources(cpu=1, memory="2Gi"),
    image=image,
    include=_INCLUDE,
    secrets=[flyte.Secret(key="internal-anthropic-api-key", as_env_var="ANTHROPIC_API_KEY")],
    depends_on=[experiment_env, bundle_env],
)
# {{/docs-fragment env}}

@dataclass
class AutoresearchBundle:
    data_dir: Dir
    tokenizer_dir: Dir

@bundle_env.task(cache="auto")
async def build_bundle(num_shards: int = DEFAULT_NUM_SHARDS, download_workers: int = 4) -> AutoresearchBundle:
    """Download climbmix shards + train the BPE tokenizer; cache the result."""
    import prepare

    cache = tempfile.mkdtemp(prefix="autoresearch-cache-")
    os.environ["AUTORESEARCH_CACHE"] = cache
    prepare.download_data(num_shards, download_workers=download_workers)
    prepare.train_tokenizer()
    data_dir = await Dir.from_local(prepare.data_dir())
    tokenizer_dir = await Dir.from_local(prepare.tokenizer_dir())
    return AutoresearchBundle(data_dir=data_dir, tokenizer_dir=tokenizer_dir)

@bundle_env.task(cache="auto")
async def profile_bundle(bundle: AutoresearchBundle) -> DatasetProfile:
    """Summarize the prepared bundle for the agent's context."""
    import prepare

    data_dir = await bundle.data_dir.download()
    tokenizer_dir = await bundle.tokenizer_dir.download()
    parquet_files = sorted(p.name for p in Path(data_dir).glob("*.parquet"))
    data_bytes = sum(p.stat().st_size for p in Path(data_dir).glob("**/*") if p.is_file())
    tok_bytes = sum(p.stat().st_size for p in Path(tokenizer_dir).glob("**/*") if p.is_file())
    return DatasetProfile(
        n_parquet_files=len(parquet_files),
        parquet_files=parquet_files,
        vocab_size=prepare.VOCAB_SIZE,
        data_bytes=data_bytes,
        tokenizer_bytes=tok_bytes,
    )

async def materialize_cache(bundle: AutoresearchBundle) -> str:
    """Download the bundle into an AUTORESEARCH_CACHE-shaped scratch dir."""
    cache = tempfile.mkdtemp(prefix="autoresearch-run-")
    os.environ["AUTORESEARCH_CACHE"] = cache
    await bundle.data_dir.download(local_path=os.path.join(cache, "data"))
    await bundle.tokenizer_dir.download(local_path=os.path.join(cache, "tokenizer"))
    return cache
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/parallelized_autoresearch/bundle.py*

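Supporting modules (`autoresearch_types.py`, `train.py`, `prepare.py`, `tools.py`, and `ui.py`) live alongside the entry point in the example directory; `_INCLUDE` collects every `.py` file there so they are bundled with each task.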

## Right-size experiments with `call_handler`

The right-sizing logic lives in `tools.py`. `execute_with_right_sizing` asks the LLM for a resource estimate, runs the underlying `@env.task` with `override(resources=...)`, and on `flyte.errors.OOMError` or a sandbox-reported OOM flag bumps memory and retries until the run succeeds or `max_oom_retries` is exhausted. After the last retry, a Flyte OOM is re-raised, while a sandbox OOM result is returned with `oom` still set:

```
"""Agent tools, sandbox execution, and memory helpers for parallelized autoresearch."""

from __future__ import annotations

import asyncio
import dataclasses
import hashlib
import json
import re
import textwrap
import xml.etree.ElementTree as ET
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

import flyte
import flyte.errors
from flyte.ai.agents import LLMCallable, LLMMessage, MemoryStore, ToolFn, tool
from flyte.ai.agents._llm import _default_call_llm

from autoresearch_types import (
    CONFIG_ONLY_EDIT_LIMIT,
    DEFAULT_NUM_SHARDS,
    DatasetProfile,
    ExperimentConfig,
    HypothesisEntry,
    MAX_DEVICE_BATCH_SIZE,
    MAX_MAX_STEPS,
    MAX_N_EMBD,
    MAX_N_HEAD,
    MAX_N_LAYER,
)
from bundle import agent_env, build_bundle, bundle_env, profile_bundle

MEMORY_KEY_FANOUT = "parallelized-autoresearch"

MAX_LLM_RETRIES = 5
INITIAL_BACKOFF_SEC = 2.0

async def call_llm(
    model: str,
    system: str,
    messages: list[dict[str, Any]],
    tools: list[dict[str, Any]] | None,
) -> LLMMessage:
    """Call litellm via the Flyte default callback, retrying transient provider errors."""
    import litellm

    backoff = INITIAL_BACKOFF_SEC
    last_exc: Exception | None = None
    for attempt in range(MAX_LLM_RETRIES):
        try:
            return await _default_call_llm(model, system, messages, tools)
        except litellm.InternalServerError as exc:
            last_exc = exc
            if attempt >= MAX_LLM_RETRIES - 1:
                break
            flyte.logger.warning(
                "LLM InternalServerError (attempt %d/%d); retrying in %.1fs: %s",
                attempt + 1,
                MAX_LLM_RETRIES,
                backoff,
                exc,
            )
            await asyncio.sleep(backoff)
            backoff *= 2
    assert last_exc is not None
    raise last_exc

RESOURCE_FLOOR = flyte.Resources(cpu=2, memory="2Gi")
RESOURCE_CEILING = flyte.Resources(cpu=16, memory="32Gi")
_MEM_RE = re.compile(r"^\s*([0-9]*\.?[0-9]+)\s*([A-Za-z]+)?\s*$")

def _memory_to_mib(memory: str | None) -> int:
    if not memory:
        return 2048
    match = _MEM_RE.match(memory)
    if not match:
        return 2048
    value = float(match.group(1))
    unit = (match.group(2) or "Mi").lower()
    if unit in ("gi", "g", "gb"):
        return int(value * 1024)
    if unit in ("mi", "m", "mb"):
        return int(value)
    if unit in ("ki", "k", "kb"):
        return max(1, int(value // 1024))
    return int(value)

def _mib_to_memory(mib: int) -> str:
    if mib >= 1024 and mib % 1024 == 0:
        return f"{mib // 1024}Gi"
    return f"{mib}Mi"

def _cap_resources(resources: flyte.Resources) -> flyte.Resources:
    floor_cpu = int(RESOURCE_FLOOR.cpu or 2)
    ceil_cpu = int(RESOURCE_CEILING.cpu or 16)
    cpu = int(resources.cpu or floor_cpu)
    cpu = max(floor_cpu, min(ceil_cpu, cpu))

    floor_mib = _memory_to_mib(
        RESOURCE_FLOOR.memory if isinstance(RESOURCE_FLOOR.memory, str) else "2Gi"
    )
    ceil_mib = _memory_to_mib(
        RESOURCE_CEILING.memory if isinstance(RESOURCE_CEILING.memory, str) else "32Gi"
    )
    mem_mib = _memory_to_mib(resources.memory if isinstance(resources.memory, str) else None)
    mem_mib = max(floor_mib, min(ceil_mib, mem_mib))
    # Keep other fields (e.g. disk from the LLM spec); only clamp cpu and memory
    return dataclasses.replace(resources, cpu=cpu, memory=_mib_to_memory(mem_mib))

def _ensure_oom_increase(resources: flyte.Resources, previous: flyte.Resources) -> flyte.Resources:
    """If memory did not grow after OOM, bump deterministically up to the ceiling."""
    prev_mib = _memory_to_mib(previous.memory if isinstance(previous.memory, str) else None)
    new_mib = _memory_to_mib(resources.memory if isinstance(resources.memory, str) else None)
    if new_mib <= prev_mib:
        ceil_mib = _memory_to_mib(
            RESOURCE_CEILING.memory if isinstance(RESOURCE_CEILING.memory, str) else "32Gi"
        )
        new_mib = min(ceil_mib, max(prev_mib * 2, prev_mib + 2048))
        resources = dataclasses.replace(resources, memory=_mib_to_memory(new_mib))
    prev_cpu = int(previous.cpu or RESOURCE_FLOOR.cpu or 2)
    new_cpu = int(resources.cpu or prev_cpu)
    if new_cpu < prev_cpu:
        resources = dataclasses.replace(resources, cpu=prev_cpu)
    return _cap_resources(resources)

def bump_memory(resources: flyte.Resources) -> flyte.Resources:
    """Deterministic memory bump after OOM."""
    return _ensure_oom_increase(resources, resources)

MAX_OOM_RETRIES = 3

RESOURCE_SIZING_SYSTEM_PROMPT = """\
You are a Kubernetes capacity planner for Flyte autoresearch sandbox training runs. \
Given a task's name, its docstring, and the concrete arguments it is about to be \
called with, estimate the *minimum sensible* compute it needs to finish without \
being OOM-killed, while not wildly over-provisioning.

Reason about the work implied by the arguments:
- TinyGPT training is memory-bound: scale with model width/depth (n_layer, n_embd, \
n_head), device_batch_size, and sequence length (512 in this workshop).
- Larger models and batch sizes need more RAM; CPU helps dataloader throughput but \
memory is usually the bottleneck.
- Sandbox runs are capped at a short time_budget_sec wall clock — prefer enough \
memory to survive peak activation usage over extra CPU.

Respond with ONLY a JSON object (no prose, no code fences) with any of these keys:
  - "cpu":    a number of cores, e.g. 2, 4, 8
  - "memory": a Kubernetes memory string, e.g. "4Gi", "16Gi"
  - "disk":   a Kubernetes disk string, e.g. "10Gi" (omit unless large I/O)
Omit a key to accept the default. Do not include any other keys. No GPUs are \
available on this cluster.

Example response: {"cpu": 4, "memory": "8Gi"}
"""

_ALLOWED_RESOURCE_KEYS = ("cpu", "memory", "disk", "shm")
_JSON_OBJECT_RE = re.compile(r"\{.*\}", re.DOTALL)

def _extract_json(text: str | None) -> dict[str, Any]:
    """Best-effort extraction of a single JSON object from an LLM reply."""
    if not text:
        return {}
    match = _JSON_OBJECT_RE.search(text)
    if not match:
        return {}
    try:
        parsed = json.loads(match.group(0))
    except json.JSONDecodeError:
        return {}
    return parsed if isinstance(parsed, dict) else {}

def _resources_from_spec(spec: dict[str, Any], floor: flyte.Resources) -> flyte.Resources:
    """Merge an LLM-produced spec onto the floor, keeping only known keys."""
    kwargs: dict[str, Any] = {
        "cpu": floor.cpu,
        "memory": floor.memory,
        "gpu": floor.gpu,
        "disk": floor.disk,
        "shm": floor.shm,
    }
    for key in _ALLOWED_RESOURCE_KEYS:
        value = spec.get(key)
        if value in (None, "", "null"):
            continue
        kwargs[key] = value
    try:
        return _cap_resources(flyte.Resources(**kwargs))
    except Exception as exc:  # pragma: no cover - defensive against bad model output
        flyte.logger.warning("Invalid resource spec %s (%s); falling back to floor.", spec, exc)
        return floor

async def estimate_resources(
    call_llm: LLMCallable,
    model: str,
    tool_name: str,
    description: str,
    args: dict[str, Any],
) -> flyte.Resources:
    """Ask the LLM to size the compute for a single tool call."""
    user = json.dumps({"tool": tool_name, "description": description, "arguments": args}, default=str)
    try:
        reply = await call_llm(
            model,
            RESOURCE_SIZING_SYSTEM_PROMPT,
            [{"role": "user", "content": user}],
            None,
        )
        spec = _extract_json(reply.content)
    except Exception as exc:  # pragma: no cover - never let sizing break the tool
        flyte.logger.warning("Resource right-sizing LLM call failed (%s); using floor.", exc)
        spec = {}
    resources = _resources_from_spec(spec, RESOURCE_FLOOR)
    flyte.logger.info("right-size %s %s -> %s", tool_name, args, resources)
    return resources

# {{docs-fragment right_size}}
async def execute_with_right_sizing(
    call_llm: LLMCallable,
    target_task: Any,
    *,
    model: str,
    tool_name: str,
    description: str,
    max_oom_retries: int = MAX_OOM_RETRIES,
    **kwargs: Any,
) -> dict:
    """LLM-size *target_task*, run it, and retry with more memory on OOM."""
    resources = await estimate_resources(call_llm, model, tool_name, description, kwargs)
    attempt = 0
    while True:
        try:
            with flyte.group(f"{tool_name}-attempt-{attempt + 1}"):
                result = await target_task.override(resources=resources).aio(**kwargs)
        except flyte.errors.OOMError:
            if attempt >= max_oom_retries:
                flyte.logger.error("%s Flyte OOM after %d retries; giving up.", tool_name, attempt)
                raise
            resources = bump_memory(resources)
            attempt += 1
            flyte.logger.warning(
                "%s Flyte OOM; retrying with memory=%s",
                tool_name,
                resources.memory,
            )
            continue

        if isinstance(result, dict):
            result["resources"] = f"cpu={resources.cpu}, mem={resources.memory}"
            result["oom_retries"] = attempt

        if isinstance(result, dict) and result.get("oom"):
            if attempt >= max_oom_retries:
                return result
            resources = bump_memory(resources)
            attempt += 1
            flyte.logger.warning(
                "%s sandbox OOM; retrying with memory=%s",
                tool_name,
                resources.memory,
            )
            continue

        return result

def right_sizing_handler(*, max_oom_retries: int = MAX_OOM_RETRIES):
    """Build a ``@tool`` ``call_handler`` that right-sizes and self-heals on OOM."""

    async def handle(call_llm: LLMCallable, tool_fn: ToolFn, **kwargs: Any) -> Any:
        return await execute_with_right_sizing(
            call_llm,
            tool_fn.target,
            model=tool_fn.model,
            tool_name=tool_fn.name,
            description=tool_fn.description,
            max_oom_retries=max_oom_retries,
            **kwargs,
        )

    return handle

right_size = right_sizing_handler(max_oom_retries=MAX_OOM_RETRIES)
# {{/docs-fragment right_size}}
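
# --- Illustrative sketch (hypothetical, not used by the agent) ---
# The OOM loop in ``execute_with_right_sizing`` reduced to plain, synchronous
# Python. ``_SketchOOM`` stands in for ``flyte.errors.OOMError``, and doubling
# a "<n>Gi" string is an assumed escalation policy. The real ``bump_memory``
# is defined elsewhere in this file and may behave differently.
from collections.abc import Callable


class _SketchOOM(Exception):
    """Hypothetical stand-in for an out-of-memory failure."""


def _sketch_double_memory(memory: str) -> str:
    # Assumes memory strings look like "4Gi" and doubles the numeric part.
    return f"{int(memory.removesuffix('Gi')) * 2}Gi"


def _sketch_retry_with_more_memory(
    run: Callable[[str], dict], memory: str, max_retries: int = 2
) -> dict:
    attempt = 0
    while True:
        try:
            result = run(memory)
        except _SketchOOM:
            if attempt >= max_retries:
                raise  # Out of retries: surface the OOM to the caller.
            memory = _sketch_double_memory(memory)
            attempt += 1
            continue
        # Report the memory that finally worked and how many retries it took.
        return {**result, "memory": memory, "oom_retries": attempt}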

def _find_leaderboard_entry(entries: list[dict[str, Any]], title: str) -> dict[str, Any] | None:
    """Find *title* in *entries*: exact case-insensitive match first, then substring match."""
    title_lower = title.strip().lower()
    for entry in entries:
        if str(entry.get("title", "")).strip().lower() == title_lower:
            return entry
    for entry in entries:
        if title_lower in str(entry.get("title", "")).strip().lower():
            return entry
    return None

@tool
@agent_env.task(retries=3)
async def search_arxiv(query: str, max_results: int = 4) -> str:
    """Search arXiv for papers relevant to the next experiment.

    Use this to gather external context on architectures, optimizers, or
    evaluation metrics before proposing a new TinyGPT configuration.

    Args:
        query: Free-text search query, e.g. ``small language model depth width``.
        max_results: Maximum number of papers to return (default 4).

    Returns:
        A Markdown bullet list of titles and summaries (truncated to 400
        characters), or a note if the search failed or returned nothing.
    """
    import httpx

    if not (query and query.strip()):
        return "(empty query; skip literature search)"

    url = "https://export.arxiv.org/api/query"
    params = {"search_query": f"all:{query}", "start": 0, "max_results": max_results}
    try:
        async with httpx.AsyncClient(timeout=30, follow_redirects=True) as client:
            resp = await client.get(url, params=params)
            resp.raise_for_status()
        root = ET.fromstring(resp.text)
        ns = {"atom": "http://www.w3.org/2005/Atom"}
        lines: list[str] = []
        for entry in root.findall("atom:entry", ns)[:max_results]:
            title_el = entry.find("atom:title", ns)
            title = " ".join((title_el.text or "").split())
            summary_el = entry.find("atom:summary", ns)
            summary = " ".join((summary_el.text or "").split())[:400]
            lines.append(f"- {title}\n  {summary}")
        return "\n".join(lines) if lines else "(no arXiv results; proceed without external context)"
    except (httpx.TimeoutException, httpx.ConnectError, httpx.NetworkError, ET.ParseError) as exc:
        return f"(literature search failed: {exc})"
    except httpx.HTTPStatusError as exc:
        if exc.response.status_code >= 500:
            return f"(literature search failed: {exc})"
        raise

@tool
@bundle_env.task(cache="auto")
async def inspect_dataset(num_shards: int = DEFAULT_NUM_SHARDS) -> dict:
    """Inspect the prepared climbmix corpus and BPE tokenizer bundle.

    Call this at the start of a research session to understand what data you
    are training on before spending experiment budget.

    Args:
        num_shards: Number of climbmix parquet shards to include in the bundle.

    Returns:
        A dict with shard/file metadata, vocab size, byte counts, and fixed
        training constants (``max_seq_len``, ``val_metric``).
    """
    import prepare

    bundle = await build_bundle(num_shards=num_shards)
    profile: DatasetProfile = await profile_bundle(bundle)
    return {
        **dataclasses.asdict(profile),
        "max_seq_len": prepare.MAX_SEQ_LEN,
        "val_metric": "val_bpb (lower is better)",
        "corpus": "karpathy/climbmix-400b-shuffle",
    }

@tool
@agent_env.task
async def record_hypothesis(
    title: str,
    hypothesis: str,
    expected_effect: str,
    memory_key: str = MEMORY_KEY_FANOUT,
) -> dict:
    """Record a structured hypothesis before running an experiment.

    Persists to the agent's keyed memory so later runs can see what you
    expected and whether it panned out.

    Args:
        title: Experiment title this hypothesis applies to.
        hypothesis: What you are trying and why.
        expected_effect: How you expect val_bpb to move (e.g. ``decrease ~5%``).
        memory_key: Memory namespace (use the key from your directive).

    Returns:
        The recorded hypothesis entry.
    """
    memory = await MemoryStore.get_or_create.aio(key=memory_key)
    prior: list[dict[str, Any]] = await memory.read_json.aio("memory/hypotheses.json", default=[])
    entry = HypothesisEntry(
        title=title,
        hypothesis=hypothesis,
        expected_effect=expected_effect,
        recorded_at=datetime.now(timezone.utc).isoformat(timespec="seconds"),
    )
    prior.append(dataclasses.asdict(entry))
    await memory.write_json.aio(
        "memory/hypotheses.json",
        prior,
        actor="mle-autoresearch-agent",
        reason=f"hypothesis for {title}",
    )
    await memory.save.aio()
    return dataclasses.asdict(entry)

@tool
@agent_env.task
async def get_leaderboard(memory_key: str = MEMORY_KEY_FANOUT) -> dict:
    """Return the persisted experiment leaderboard from agent memory.

    Use this to recall prior runs across sessions. Experiments from the
    *current* session also appear in your tool-call transcript.

    Args:
        memory_key: Memory namespace (use the key from your directive).

    Returns:
        A dict with ``entries`` (each annotated with ``beat_best`` and
        ``delta_vs_best``), ``best`` (entry or null), ``best_val_bpb``, and ``count``.
    """
    memory = await MemoryStore.get_or_create.aio(key=memory_key)
    entries: list[dict[str, Any]] = await memory.read_json.aio("memory/leaderboard.json", default=[])
    best: dict[str, Any] | None = None
    best_val = float("inf")
    for entry in entries:
        val = entry.get("val_bpb")
        if val is not None and float(val) < best_val:
            best_val = float(val)
            best = entry
    best_f = best_val if best_val != float("inf") else None
    enriched: list[dict[str, Any]] = []
    for entry in entries:
        val = entry.get("val_bpb")
        val_f = float(val) if val is not None else None
        enriched.append(
            {
                **entry,
                "beat_best": val_f is not None and best_f is not None and val_f <= best_f,
                "delta_vs_best": (val_f - best_f) if val_f is not None and best_f is not None else None,
            }
        )
    return {
        "entries": enriched,
        "best": best,
        "best_val_bpb": best_f,
        "count": len(enriched),
    }

@tool
@agent_env.task
async def compare_experiments(
    title_a: str,
    title_b: str,
    memory_key: str = MEMORY_KEY_FANOUT,
) -> dict:
    """Compare two prior experiments side-by-side.

    Looks up both titles in the persisted leaderboard. For experiments run in
    the current session that are not yet persisted, use the values from your
    recent ``run_experiment`` tool results instead.

    Args:
        title_a: Title of the first experiment.
        title_b: Title of the second experiment.
        memory_key: Memory namespace (use the key from your directive).

    Returns:
        A dict with ``a``, ``b``, ``delta_val_bpb`` (a minus b; negative means
        a is better), ``missing`` (titles not found in memory), and ``note``.
    """
    memory = await MemoryStore.get_or_create.aio(key=memory_key)
    entries: list[dict[str, Any]] = await memory.read_json.aio("memory/leaderboard.json", default=[])
    a = _find_leaderboard_entry(entries, title_a)
    b = _find_leaderboard_entry(entries, title_b)
    missing = [t for t, e in ((title_a, a), (title_b, b)) if e is None]
    delta: float | None = None
    if a is not None and b is not None and a.get("val_bpb") is not None and b.get("val_bpb") is not None:
        delta = float(a["val_bpb"]) - float(b["val_bpb"])
    return {
        "a": a,
        "b": b,
        "delta_val_bpb": delta,
        "missing": missing,
        "note": (
            "Some titles were not found in persisted memory; check recent run_experiment "
            "tool results in your transcript for the current session."
            if missing
            else None
        ),
    }

_CONFIG_FIELDS = {f.name for f in dataclasses.fields(ExperimentConfig)} - {"title"}
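# Matches the ``run_training`` signature line plus its docstring (if present), so
# config overrides can be injected as the first statements of the function body.
_RUN_TRAINING_DOC = re.compile(
    r"(def run_training\(config: ExperimentConfig\)[^:]*:\n(?:    \"\"\"[\s\S]*?\"\"\"\n)?)"
)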

def normalize_train_py(text: str) -> str:
    """Normalize line endings and surrounding whitespace so sources compare by content."""
    return text.replace("\r\n", "\n").strip()

def baseline_train_py() -> str:
    """Return the repo baseline ``train.py`` (single source of truth for diffs)."""
    import train

    assert train.__file__ is not None
    return Path(train.__file__).read_text()

def filter_config_overrides(overrides: dict[str, Any] | None) -> dict[str, Any]:
    """Keep only known ``ExperimentConfig`` fields and clamp sizes to safe bounds."""
    if not overrides:
        return {}
    filtered = {k: v for k, v in overrides.items() if k in _CONFIG_FIELDS}
    # Clamp integer knobs into [1, cap] so the agent cannot request oversized runs.
    caps = {
        "n_layer": MAX_N_LAYER,
        "n_head": MAX_N_HEAD,
        "n_embd": MAX_N_EMBD,
        "device_batch_size": MAX_DEVICE_BATCH_SIZE,
        "max_steps": MAX_MAX_STEPS,
    }
    for key, cap in caps.items():
        if key in filtered:
            filtered[key] = max(1, min(int(filtered[key]), cap))
    # Attention needs n_embd divisible by n_head: round down to a multiple,
    # but never below one head width (e.g. n_embd=4 with 8 heads would become 0).
    if "n_embd" in filtered and "n_head" in filtered:
        head = filtered["n_head"]
        if filtered["n_embd"] % head != 0:
            filtered["n_embd"] = max(head, (filtered["n_embd"] // head) * head)
    return filtered
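
# --- Illustrative sketch (hypothetical, not used by the agent) ---
# The clamping rule of ``filter_config_overrides`` with made-up caps (the real
# MAX_* constants are defined elsewhere in this file). Integers are clamped
# into [1, cap], unknown keys are dropped, and n_embd is rounded down to a
# multiple of n_head; this sketch also keeps n_embd at least one head wide.
_SKETCH_CAPS = {"n_layer": 8, "n_head": 8, "n_embd": 512}


def _sketch_clamp(overrides: dict[str, int]) -> dict[str, int]:
    out = {
        k: max(1, min(int(v), _SKETCH_CAPS[k]))
        for k, v in overrides.items()
        if k in _SKETCH_CAPS
    }
    if "n_embd" in out and "n_head" in out and out["n_embd"] % out["n_head"]:
        head = out["n_head"]
        out["n_embd"] = max(head, out["n_embd"] // head * head)
    return out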

def is_config_only_edit(train_py: str, overrides: dict[str, Any] | None) -> bool:
    """True when *train_py* differs from baseline only via ``config_overrides`` injection."""
    baseline = baseline_train_py()
    filtered = filter_config_overrides(overrides)
    if not filtered:
        return normalize_train_py(train_py) == normalize_train_py(baseline)
    expected = build_train_py_with_config_overrides(baseline, filtered)
    return normalize_train_py(train_py) == normalize_train_py(expected)

def experiment_config_signature(train_py: str, overrides: dict[str, Any] | None) -> str:
    """Stable hash of effective train code + config overrides for duplicate detection."""
    filtered = filter_config_overrides(overrides)
    payload = {
        "train_py": normalize_train_py(train_py),
        "overrides": sorted(filtered.items()),
    }
    return hashlib.sha256(json.dumps(payload, sort_keys=True, default=str).encode()).hexdigest()[:16]
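
# --- Illustrative sketch (hypothetical, not used by the agent) ---
# Why ``experiment_config_signature`` normalizes and sorts before hashing: the
# same code and overrides must produce the same 16-hex-digit signature no
# matter the line endings or the order in which the agent listed the keys.
import hashlib
import json


def _sketch_signature(code: str, overrides: dict) -> str:
    payload = {
        "train_py": code.replace("\r\n", "\n").strip(),
        "overrides": sorted(overrides.items()),
    }
    blob = json.dumps(payload, sort_keys=True, default=str).encode()
    return hashlib.sha256(blob).hexdigest()[:16]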

async def check_duplicate_config(
    memory_key: str,
    title: str,
    train_py: str,
    overrides: dict[str, Any] | None,
) -> dict[str, Any] | None:
    """Return duplicate metadata if this config was already run under another title."""
    sig = experiment_config_signature(train_py, overrides)
    memory = await MemoryStore.get_or_create.aio(key=memory_key)
    sigs: dict[str, str] = await memory.read_json.aio("memory/config_signatures.json", default={})
    prior_title = sigs.get(sig)
    title_key = title.strip().lower()
    if prior_title and prior_title.strip().lower() != title_key:
        return {"duplicate_of": prior_title, "config_signature": sig}
    return None

async def register_config_signature(
    memory_key: str,
    title: str,
    train_py: str,
    overrides: dict[str, Any] | None,
    *,
    actor: str = "mle-autoresearch-code-agent",
) -> str:
    """Record the config signature for *title* after a successful edit or run."""
    sig = experiment_config_signature(train_py, overrides)
    memory = await MemoryStore.get_or_create.aio(key=memory_key)
    sigs: dict[str, str] = await memory.read_json.aio("memory/config_signatures.json", default={})
    sigs[sig] = title
    await memory.write_json.aio(
        "memory/config_signatures.json",
        sigs,
        actor=actor,
        reason=f"config signature for {title}",
    )
    await memory.save.aio()
    return sig

def build_train_py_with_config_overrides(
    base_code: str,
    overrides: dict[str, Any],
) -> str:
    """Inject ``dataclasses.replace(config, ...)`` at the top of ``run_training``."""
    filtered = filter_config_overrides(overrides)
    if not filtered:
        return base_code

    parts = [f"{k}={v!r}" for k, v in sorted(filtered.items())]
    injection = f"    import dataclasses\n    config = dataclasses.replace(config, {', '.join(parts)})\n"
    match = _RUN_TRAINING_DOC.search(base_code)
    if match:
        insert_at = match.end()
        return base_code[:insert_at] + injection + base_code[insert_at:]
    return base_code
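
# --- Illustrative sketch (hypothetical, not used by the agent) ---
# What the override injection produces on a toy ``run_training``. The regex is
# modeled on ``_RUN_TRAINING_DOC``: match the signature line and docstring,
# then splice a ``dataclasses.replace`` call in right after them. If nothing
# matches, the source comes back unchanged.
import re

_SKETCH_DOC = re.compile(
    r'(def run_training\(config: ExperimentConfig\)[^:]*:\n(?:    """[\s\S]*?"""\n))'
)

_SKETCH_TOY = (
    "def run_training(config: ExperimentConfig) -> ExperimentResult:\n"
    '    """Train."""\n'
    "    return train(config)\n"
)


def _sketch_inject(code: str, overrides: dict) -> str:
    parts = ", ".join(f"{k}={v!r}" for k, v in sorted(overrides.items()))
    injection = f"    import dataclasses\n    config = dataclasses.replace(config, {parts})\n"
    match = _SKETCH_DOC.search(code)
    if not match:
        return code
    return code[: match.end()] + injection + code[match.end():]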

async def load_config_overrides(memory_key: str, title: str) -> dict[str, Any]:
    """Load persisted ``ExperimentConfig`` overrides for an experiment title."""
    memory = await MemoryStore.get_or_create.aio(key=memory_key)
    slug = slugify(title)
    stored = await memory.read_json.aio(f"memory/config/{slug}.json", default={})
    if stored:
        return filter_config_overrides(stored)

    index: list[dict[str, Any]] = await memory.read_json.aio("memory/code_index.json", default=[])
    title_lower = title.strip().lower()
    for entry in index:
        if str(entry.get("title", "")).strip().lower() == title_lower:
            slug = str(entry.get("slug", slug))
            stored = await memory.read_json.aio(f"memory/config/{slug}.json", default={})
            if stored:
                return filter_config_overrides(stored)
            return filter_config_overrides(entry.get("config_overrides") or {})
    return {}

def slugify(title: str) -> str:
    """Turn *title* into a lowercase, hyphen-separated file name (at most 80 characters)."""
    slug = re.sub(r"[^a-z0-9]+", "-", title.lower()).strip("-")
    return slug[:80] or "experiment"

async def load_train_code(memory_key: str, title: str) -> str:
    """Load edited ``train.py`` for *title*, falling back to the repo baseline."""
    memory = await MemoryStore.get_or_create.aio(key=memory_key)
    slug = slugify(title)
    saved = await memory.read_text.aio(f"memory/code/{slug}.py", default="")
    if saved.strip():
        return saved

    index: list[dict[str, Any]] = await memory.read_json.aio("memory/code_index.json", default=[])
    title_lower = title.strip().lower()
    for entry in index:
        if str(entry.get("title", "")).strip().lower() == title_lower:
            slug = entry.get("slug", slug)
            saved = await memory.read_text.aio(f"memory/code/{slug}.py", default="")
            if saved.strip():
                return saved

    return baseline_train_py()

async def _global_best_val_bpb(memory: MemoryStore, *, exclude_title: str | None = None) -> float:
    """Lowest val_bpb recorded in memory (optionally excluding one title)."""
    exclude = (exclude_title or "").strip().lower()
    leaderboard: list[dict[str, Any]] = await memory.read_json.aio("memory/leaderboard.json", default=[])
    promising: list[dict[str, Any]] = await memory.read_json.aio("memory/promising_code.json", default=[])
    vals: list[float] = []
    for row in leaderboard + promising:
        if exclude and str(row.get("title", "")).strip().lower() == exclude:
            continue
        val = row.get("val_bpb")
        if val is not None:
            vals.append(float(val))
    return min(vals, default=float("inf"))

async def _update_promising_code(
    memory_key: str,
    *,
    title: str,
    slug: str,
    val_bpb: float,
    change_summary: str,
) -> None:
    memory = await MemoryStore.get_or_create.aio(key=memory_key)
    promising: list[dict[str, Any]] = await memory.read_json.aio("memory/promising_code.json", default=[])
    prior_best = await _global_best_val_bpb(memory, exclude_title=title)
    kept = val_bpb < prior_best
    promising.append(
        {
            "title": title,
            "slug": slug,
            "val_bpb": val_bpb,
            "kept": kept,
            "change_summary": change_summary,
            "recorded_at": datetime.now(timezone.utc).isoformat(timespec="seconds"),
        }
    )
    await memory.write_json.aio(
        "memory/promising_code.json",
        promising,
        actor="mle-autoresearch-code-agent",
        reason=f"promising code after {title} val_bpb={val_bpb}",
    )
    await memory.save.aio()

async def _resolve_train_py_for_edit(
    memory_key: str,
    spec: dict[str, Any],
) -> tuple[str, dict[str, Any], str | None]:
    """Build the effective ``train.py`` source and overrides for one edit spec."""
    train_py = spec.get("train_py", "")
    if not isinstance(train_py, str):
        train_py = ""
    config_overrides = filter_config_overrides(
        spec.get("config_overrides") or spec.get("config") or {}
    )
    parent_title = spec.get("parent_title")
    parent_title = str(parent_title).strip() if parent_title else None

    baseline = baseline_train_py()
    if config_overrides:
        base_code = await load_train_code(memory_key, parent_title) if parent_title else baseline
        if not train_py.strip() or normalize_train_py(train_py) == normalize_train_py(baseline):
            train_py = build_train_py_with_config_overrides(base_code, config_overrides)
        elif parent_title and normalize_train_py(train_py) == normalize_train_py(base_code):
            train_py = build_train_py_with_config_overrides(base_code, config_overrides)

    return train_py, config_overrides, parent_title

async def _persist_train_edits(
    memory_key: str,
    edits: list[dict[str, Any]],
    *,
    actor: str = "mle-autoresearch-code-agent",
) -> dict[str, Any]:
    """Save one or more ``train.py`` edits in a single memory transaction."""
    memory = await MemoryStore.get_or_create.aio(key=memory_key)
    index: list[dict[str, Any]] = await memory.read_json.aio("memory/code_index.json", default=[])
    saved: list[dict[str, Any]] = []
    errors: list[dict[str, Any]] = []
    now = datetime.now(timezone.utc).isoformat(timespec="seconds")

    for spec in edits:
        title = str(spec.get("title", "")).strip()
        change_summary = str(spec.get("change_summary", ""))
        if not title:
            errors.append({"title": title or "(missing)", "saved": False, "error": "title is required"})
            continue

        train_py, config_overrides, parent_title = await _resolve_train_py_for_edit(memory_key, spec)
        if not train_py.strip():
            errors.append(
                {
                    "title": title,
                    "saved": False,
                    "error": "train_py or config_overrides is required",
                }
            )
            continue
        if is_config_only_edit(train_py, config_overrides) and len(index) >= CONFIG_ONLY_EDIT_LIMIT:
            errors.append(
                {
                    "title": title,
                    "saved": False,
                    "error": (
                        f"Batch 2+ requires substantive train.py edits (LR schedule, optimizer, "
                        f"weight decay, grad clip, etc.), not config_overrides alone. "
                        f"You already have {len(index)} saved edit(s)."
                    ),
                }
            )
            continue
        if normalize_train_py(train_py) == normalize_train_py(baseline_train_py()) and not config_overrides:
            errors.append(
                {
                    "title": title,
                    "saved": False,
                    "error": (
                        "train.py matches baseline with no config_overrides; "
                        "pass config_overrides={n_layer: 6, ...} or edit run_training"
                    ),
                }
            )
            continue
        if "def run_training" not in train_py:
            errors.append(
                {
                    "title": title,
                    "saved": False,
                    "error": "train_py must define run_training(config) like the baseline train.py",
                }
            )
            continue

        slug = slugify(title)
        await memory.write_text.aio(
            f"memory/code/{slug}.py",
            train_py,
            actor=actor,
            reason=f"edit train.py for {title}",
        )
        if config_overrides:
            await memory.write_json.aio(
                f"memory/config/{slug}.json",
                config_overrides,
                actor=actor,
                reason=f"config overrides for {title}",
            )
        index.append(
            {
                "title": title,
                "slug": slug,
                "change_summary": change_summary,
                "lines": len(train_py.splitlines()),
                "edited_at": now,
                "config_overrides": config_overrides,
                "parent_title": parent_title,
            }
        )
        saved.append(
            {
                "saved": True,
                "title": title,
                "slug": slug,
                "lines": len(train_py.splitlines()),
                "change_summary": change_summary,
                "train_py": train_py,
                "config_overrides": config_overrides,
                "parent_title": parent_title,
                "memory_path": f"memory/code/{slug}.py",
            }
        )

    if saved:
        await memory.write_json.aio(
            "memory/code_index.json",
            index,
            actor=actor,
            reason=f"code index update ({len(saved)} edit(s))",
        )
        await memory.save.aio()

    return {
        "count": len(saved),
        "titles": [row["title"] for row in saved],
        "edits": saved,
        "errors": errors,
    }

@tool
@agent_env.task
async def get_baseline_train_code() -> dict:
    """Return the baseline ``train.py`` from the repo (the karpathy/autoresearch recipe).

    Use this once at the start to understand the starting point before editing.

    Returns:
        A dict with ``title``, ``train_py`` (full source), and ``lines``.
    """
    code = baseline_train_py()
    return {"title": "baseline", "train_py": code, "lines": len(code.splitlines())}

@tool
@agent_env.task
async def edit_train_code(
    title: str,
    train_py: str = "",
    change_summary: str = "",
    memory_key: str = MEMORY_KEY_FANOUT,
    config_overrides: dict[str, Any] | None = None,
    parent_title: str | None = None,
) -> dict:
    """Save an edited ``train.py`` for this experiment to agent memory.

    The code must keep a ``run_training(config: ExperimentConfig) -> ExperimentResult``
    entry point (same as the baseline). Only edit architecture, optimizer, and
    training-loop knobs inside the file.

    Alternatively pass ``config_overrides`` (e.g. ``{"n_layer": 6, "learning_rate": 1e-4}``)
    instead of a full ``train_py`` rewrite — the platform injects
    ``dataclasses.replace(config, ...)`` into ``run_training`` for you.

    Args:
        title: Short human-readable experiment name (used as the memory key slug).
        train_py: Full Python source for the edited training script (optional if
            ``config_overrides`` is set).
        change_summary: One-line description of what you changed and why.
        memory_key: Memory namespace from your directive.
        config_overrides: Optional ``ExperimentConfig`` field overrides.
        parent_title: Optional prior experiment to fork before applying overrides.

    Returns:
        Metadata about the saved edit, including the full ``train_py`` source
        (visible in the Flyte task output UI).
    """
    result = await _persist_train_edits(
        memory_key,
        [
            {
                "title": title,
                "train_py": train_py,
                "change_summary": change_summary,
                "config_overrides": config_overrides,
                "parent_title": parent_title,
            }
        ],
    )
    if result["edits"]:
        return result["edits"][0]
    err = result["errors"][0] if result["errors"] else {"saved": False, "error": "unknown error"}
    return err

@tool
@agent_env.task
async def edit_train_code_batch(
    edits: list[dict[str, Any]],
    memory_key: str = MEMORY_KEY_FANOUT,
) -> dict:
    """Save multiple edited ``train.py`` files in one atomic memory write.

    Use this when preparing a parallel experiment batch: one call replaces several
    sequential ``edit_train_code`` calls and avoids races on ``memory/code_index.json``.

    Each item in ``edits`` must include ``title`` and ``change_summary``, plus either
    ``train_py`` (full source) or ``config_overrides`` (e.g. ``{"n_layer": 6}``).
    Optional ``parent_title`` forks a prior experiment before applying overrides.
    Every ``train_py`` must keep the ``run_training(config)`` entry point.

    Args:
        edits: List of edit specs, e.g.
            ``[{"title": "deeper-6L", "config_overrides": {"n_layer": 6}, "change_summary": "..."}]``.
        memory_key: Memory namespace from your directive.

    Returns:
        A dict with ``count``, ``titles``, ``edits`` (each includes ``train_py``),
        and ``errors`` (specs that were rejected, with the reason).
    """
    if not edits:
        return {"count": 0, "titles": [], "edits": [], "errors": [{"error": "edits list is empty"}]}
    return await _persist_train_edits(
        memory_key,
        edits,
        actor="parallelized-autoresearch",
    )

@tool
@agent_env.task
async def read_train_code(title: str, memory_key: str = MEMORY_KEY_FANOUT) -> dict:
    """Read a previously saved ``train.py`` edit from memory (or the baseline).

    Args:
        title: Experiment title whose code you want to inspect.
        memory_key: Memory namespace from your directive.

    Returns:
        A dict with ``title``, ``train_py``, and ``lines``.
    """
    code = await load_train_code(memory_key, title)
    return {"title": title, "train_py": code, "lines": len(code.splitlines())}

@tool
@agent_env.task
async def get_promising_code(memory_key: str = MEMORY_KEY_FANOUT) -> dict:
    """Return promising ``train.py`` edits, the current best, and deltas vs best.

    Each entry records ``val_bpb`` after a successful run. Use ``read_train_code``
    with the best entry's title to inspect its source. Prefer ``get_code_edit_history``
    for the full cross-session table of edits, results, and regressions.

    Args:
        memory_key: Memory namespace from your directive.

    Returns:
        A dict with ``entries``, ``best``, ``best_val_bpb``, ``best_title``, and ``count``.
    """
    history = await load_research_history(memory_key)
    best_val = history.get("best_val_bpb")
    entries: list[dict[str, Any]] = []
    memory = await MemoryStore.get_or_create.aio(key=memory_key)
    promising: list[dict[str, Any]] = await memory.read_json.aio("memory/promising_code.json", default=[])
    for row in promising:
        val = row.get("val_bpb")
        val_f = float(val) if val is not None else None
        entries.append(
            {
                **row,
                "beat_best": val_f is not None and best_val is not None and val_f <= best_val,
                "delta_vs_best": (val_f - best_val) if val_f is not None and best_val is not None else None,
            }
        )
    best: dict[str, Any] | None = None
    if history.get("best_title"):
        best_key = str(history["best_title"]).strip().lower()
        for entry in reversed(entries):
            if str(entry.get("title", "")).strip().lower() == best_key:
                best = entry
                break
    return {
        "entries": entries,
        "best": best,
        "best_val_bpb": best_val,
        "best_title": history.get("best_title"),
        "count": len(entries),
    }

@tool
@agent_env.task
async def get_code_edit_history(memory_key: str = MEMORY_KEY_FANOUT) -> dict:
    """Return all prior code edits, run results, and whether each beat the current best.

    Call this at the start of a session when ``memory_key`` already has experiments.
    Shows every saved ``train.py`` edit, its ``change_summary``, ``val_bpb`` (if run),
    ``delta_vs_best`` (negative means better), ``outcome`` (``new_best`` / ``regression`` /
    ``failed`` / ``not_run``), and linked hypotheses.

    Args:
        memory_key: Memory namespace from your directive.

    Returns:
        A dict with ``best_val_bpb``, ``best_title``, ``trials``, and summary counts.
    """
    return await load_research_history(memory_key)

async def load_saved_code_edits(memory_key: str) -> list[dict[str, Any]]:
    """Load all saved ``train.py`` edits from memory for reporting."""
    memory = await MemoryStore.get_or_create.aio(key=memory_key)
    index: list[dict[str, Any]] = await memory.read_json.aio("memory/code_index.json", default=[])
    promising: list[dict[str, Any]] = await memory.read_json.aio("memory/promising_code.json", default=[])
    val_by_title = {
        str(row.get("title", "")).strip().lower(): row.get("val_bpb")
        for row in promising
        if row.get("val_bpb") is not None
    }
    kept_titles = {
        str(row.get("title", "")).strip().lower()
        for row in promising
        if row.get("kept")
    }

    baseline = baseline_train_py()
    edits: list[dict[str, Any]] = []
    for entry in index:
        slug = str(entry.get("slug", slugify(str(entry.get("title", "")))))
        train_py = await memory.read_text.aio(f"memory/code/{slug}.py", default="")
        title = str(entry.get("title", ""))
        title_key = title.strip().lower()
        config_overrides = filter_config_overrides(entry.get("config_overrides") or {})
        if not config_overrides:
            config_overrides = filter_config_overrides(
                await memory.read_json.aio(f"memory/config/{slug}.json", default={})
            )
        if config_overrides and normalize_train_py(train_py) == normalize_train_py(baseline):
            parent_title = entry.get("parent_title")
            base_code = (
                await load_train_code(memory_key, str(parent_title))
                if parent_title
                else baseline
            )
            train_py = build_train_py_with_config_overrides(base_code, config_overrides)
        edits.append(
            {
                **entry,
                "slug": slug,
                "train_py": train_py,
                "config_overrides": config_overrides,
                "memory_path": f"memory/code/{slug}.py",
                "val_bpb": val_by_title.get(title_key),
                "kept": title_key in kept_titles,
            }
        )
    return edits

async def record_experiment_result(
    memory_key: str,
    result: dict[str, Any],
    *,
    actor: str = "mle-autoresearch-code-agent",
) -> None:
    """Upsert one experiment outcome into ``memory/leaderboard.json``."""
    title = str(result.get("title", "")).strip()
    if not title:
        return
    memory = await MemoryStore.get_or_create.aio(key=memory_key)
    leaderboard: list[dict[str, Any]] = await memory.read_json.aio("memory/leaderboard.json", default=[])
    row: dict[str, Any] = {
        "title": title,
        "success": bool(result.get("success")),
        "val_bpb": float(result["val_bpb"]) if result.get("val_bpb") is not None else None,
        "model_name": result.get("model_name"),
        "n_params": result.get("n_params"),
        "steps": int(result["steps"]) if result.get("steps") is not None else None,
        "resources": result.get("resources"),
        "oom_retries": int(result.get("oom_retries") or 0),
    }
    if not result.get("success"):
        err = result.get("error") or result.get("stderr") or "failed"
        row["error"] = str(err)[:200]

    title_key = title.lower()
    replaced = False
    for idx, existing in enumerate(leaderboard):
        if str(existing.get("title", "")).strip().lower() == title_key:
            leaderboard[idx] = row
            replaced = True
            break
    if not replaced:
        leaderboard.append(row)

    await memory.write_json.aio(
        "memory/leaderboard.json",
        leaderboard,
        actor=actor,
        reason=f"experiment result for {title}",
    )
    await memory.save.aio()

async def record_promising_run(
    memory_key: str,
    title: str,
    result: dict[str, Any],
    change_summary: str = "",
) -> None:
    """Persist a successful run's code to the promising-code ledger."""
    if not result.get("success") or result.get("val_bpb") is None:
        return
    memory = await MemoryStore.get_or_create.aio(key=memory_key)
    code_index: list[dict[str, Any]] = await memory.read_json.aio("memory/code_index.json", default=[])
    summary = change_summary
    slug = slugify(title)
    for entry in reversed(code_index):
        if str(entry.get("title", "")).strip().lower() == title.strip().lower():
            summary = summary or str(entry.get("change_summary", ""))
            slug = str(entry.get("slug", slug))
            break
    await _update_promising_code(
        memory_key,
        title=title,
        slug=slug,
        val_bpb=float(result["val_bpb"]),
        change_summary=summary or "successful run",
    )

@tool
@agent_env.task
async def record_batch_plan(
    batch_id: str,
    experiments: list[dict[str, Any]],
    memory_key: str = MEMORY_KEY_FANOUT,
) -> dict:
    """Persist a batch of planned experiments before editing or running them.

    Each experiment dict should include at least ``title`` and ``hypothesis``.
    Optional keys: ``expected_effect``, ``change_summary``, ``parent_title``.

    Args:
        batch_id: Short identifier for this batch (e.g. ``batch-1-depth-sweep``).
        experiments: Planned experiment specs for parallel execution.
        memory_key: Memory namespace from your directive.

    Returns:
        The saved batch record with ``batch_id``, ``count``, and ``experiments``.
    """
    memory = await MemoryStore.get_or_create.aio(key=memory_key)
    batches: list[dict[str, Any]] = await memory.read_json.aio("memory/batches.json", default=[])
    record = {
        "batch_id": batch_id,
        "experiments": experiments,
        "count": len(experiments),
        "status": "planned",
        "created_at": datetime.now(timezone.utc).isoformat(timespec="seconds"),
    }
    batches.append(record)
    await memory.write_json.aio(
        "memory/batches.json",
        batches,
        actor="parallelized-autoresearch",
        reason=f"batch plan {batch_id}",
    )
    await memory.save.aio()
    return record

@tool
@agent_env.task
async def get_batch_plan(batch_id: str, memory_key: str = MEMORY_KEY_FANOUT) -> dict:
    """Load a previously recorded batch plan by ``batch_id``.

    Args:
        batch_id: Identifier passed to ``record_batch_plan``.
        memory_key: Memory namespace from your directive.

    Returns:
        The batch record, or ``{"found": False}`` if missing.
    """
    memory = await MemoryStore.get_or_create.aio(key=memory_key)
    batches: list[dict[str, Any]] = await memory.read_json.aio("memory/batches.json", default=[])
    batch_id_lower = batch_id.strip().lower()
    for batch in reversed(batches):
        if str(batch.get("batch_id", "")).strip().lower() == batch_id_lower:
            return {"found": True, **batch}
    return {"found": False, "batch_id": batch_id}

@tool
@agent_env.task
async def record_batch_hypotheses(
    experiments: list[dict[str, Any]],
    memory_key: str = MEMORY_KEY_FANOUT,
) -> dict:
    """Record hypotheses for every experiment in a batch (before ``run_experiment_batch``).

    Each item needs ``title``, ``hypothesis``, and ``expected_effect``.

    Args:
        experiments: List of hypothesis dicts (one per planned experiment title).
        memory_key: Memory namespace from your directive.

    Returns:
        A dict with ``recorded`` count and the appended entries.
    """
    memory = await MemoryStore.get_or_create.aio(key=memory_key)
    prior: list[dict[str, Any]] = await memory.read_json.aio("memory/hypotheses.json", default=[])
    recorded: list[dict[str, Any]] = []
    for spec in experiments:
        entry = HypothesisEntry(
            title=str(spec.get("title", "")),
            hypothesis=str(spec.get("hypothesis", "")),
            expected_effect=str(spec.get("expected_effect", "")),
            recorded_at=datetime.now(timezone.utc).isoformat(timespec="seconds"),
        )
        row = dataclasses.asdict(entry)
        prior.append(row)
        recorded.append(row)
    await memory.write_json.aio(
        "memory/hypotheses.json",
        prior,
        actor="parallelized-autoresearch",
        reason=f"batch hypotheses ({len(recorded)} experiments)",
    )
    await memory.save.aio()
    return {"recorded": len(recorded), "entries": recorded}

def evaluate_batch_results_impl(
    results: list[dict[str, Any]],
    batch_id: str = "",
) -> dict[str, Any]:
    """Rank and summarize the outcome of a parallel experiment batch."""
    successes: list[dict[str, Any]] = []
    failures: list[dict[str, Any]] = []
    for result in results:
        if not isinstance(result, dict):
            failures.append({"title": "?", "error": str(result)})
            continue
        if result.get("success") and result.get("val_bpb") is not None:
            successes.append(result)
        else:
            failures.append(
                {
                    "title": result.get("title", "?"),
                    "error": result.get("error") or (result.get("stderr") or "")[:200],
                    "oom": result.get("oom", False),
                }
            )

    ranked = sorted(successes, key=lambda r: float(r["val_bpb"]))
    best = ranked[0] if ranked else None
    return {
        "batch_id": batch_id or None,
        "total": len(results),
        "n_success": len(successes),
        "n_failed": len(failures),
        "ranked": [
            {
                "title": r.get("title"),
                "val_bpb": r.get("val_bpb"),
                "model_name": r.get("model_name"),
                "steps": r.get("steps"),
                "resources": r.get("resources"),
                "oom_retries": r.get("oom_retries", 0),
            }
            for r in ranked
        ],
        "best": best,
        "failures": failures,
    }

@tool
@agent_env.task
async def evaluate_batch_results(
    results: list[dict[str, Any]],
    batch_id: str = "",
) -> dict:
    """Rank and summarize the outcome of a parallel experiment batch.

    Use after ``run_experiment_batch`` or ``flyte_map("run_experiment", ...)``.
    Lower ``val_bpb`` is better.

    Args:
        results: List of ``run_experiment`` result dicts (same order as titles).
        batch_id: Optional batch label for the summary.

    Returns:
        A dict with ``batch_id``, ``total``, ``n_success``, ``n_failed``,
        ``ranked``, ``best``, and ``failures``.
    """
    return evaluate_batch_results_impl(results, batch_id=batch_id)

async def persist_run_results_to_leaderboard(
    memory_key: str,
    results: list[dict[str, Any]],
    *,
    actor: str = "parallelized-autoresearch",
) -> int:
    """Persist run results (success or failure) to ``memory/leaderboard.json``."""
    added = 0
    for result in results:
        if not isinstance(result, dict) or not result.get("title"):
            continue
        await record_experiment_result(memory_key, result, actor=actor)
        added += 1
    return added

async def run_experiment_batch_impl(
    run_experiment_task: Any,
    titles: list[str],
    *,
    time_budget_sec: int = 45,
    memory_key: str = MEMORY_KEY_FANOUT,
    concurrency: int = 4,
    group_name: str | None = None,
) -> dict[str, Any]:
    """Fan out ``run_experiment`` across *titles* via ``flyte.map``."""
    if not titles:
        return {"batch_size": 0, "results": [], "titles": []}

    n = len(titles)
    budgets = [time_budget_sec] * n
    keys = [memory_key] * n
    map_kwargs: dict[str, Any] = {"concurrency": concurrency, "return_exceptions": True}
    if group_name:
        map_kwargs["group_name"] = group_name

    results: list[Any] = []
    async for item in flyte.map.aio(run_experiment_task, titles, budgets, keys, **map_kwargs):
        if isinstance(item, BaseException):
            results.append({"success": False, "title": "?", "error": str(item)})
        else:
            results.append(item)

    return {
        "batch_size": n,
        "titles": titles,
        "results": results,
        "concurrency": concurrency,
        "group_name": group_name,
    }

OOM_MARKERS = (
    "out of memory",
    "oom",
    "cannot allocate memory",
    "can't allocate memory",
    "unable to allocate",
    "memoryerror",
    "killed",
    "signal 9",
    "std::bad_alloc",
    "defaultcpuallocator",
    "bad_alloc",
)

def is_oom(stderr: str, returncode: int | None, *, stdout: str = "") -> bool:
    """Detect OOM from sandbox stderr / exit code (137 = SIGKILL/OOM-kill)."""
    if returncode in (137, -9):
        return True
    text = f"{stderr}\n{stdout}".lower()
    return any(marker in text for marker in OOM_MARKERS)

def parse_metrics(stdout: str) -> dict[str, Any] | None:
    """Parse the ``AUTORESEARCH_METRICS=`` line emitted by the driver script."""
    for line in stdout.splitlines():
        if line.startswith("AUTORESEARCH_METRICS="):
            return json.loads(line.split("=", 1)[1])
    return None

def write_driver_script(title: str, time_budget_sec: int, eval_tokens: int) -> str:
    """Return a small driver that imports the agent-edited ``train.py`` and prints metrics."""
    return textwrap.dedent(
        f'''
        import json
        import os
        import sys

        workdir = os.path.dirname(os.path.abspath(__file__))
        os.chdir(workdir)
        os.environ["AUTORESEARCH_CACHE"] = workdir
        sys.path.insert(0, workdir)
        os.environ.setdefault("AUTORESEARCH_EVAL_TOKENS", "{eval_tokens}")

        from autoresearch_types import ExperimentConfig
        import train

        overrides = {{}}
        overrides_path = os.path.join(workdir, "config_overrides.json")
        if os.path.exists(overrides_path):
            with open(overrides_path) as f:
                overrides = json.load(f)

        config = ExperimentConfig(title={title!r}, time_budget_sec={time_budget_sec})
        if overrides:
            import dataclasses
            config = dataclasses.replace(config, **overrides)
        result = train.run_training(config)
        payload = {{
            "title": result.title,
            "val_bpb": round(result.val_bpb, 6),
            "model_name": result.model_name,
            "n_params": result.n_params,
            "steps": result.steps,
            "device": result.device,
            "notes": result.notes,
        }}
        print("AUTORESEARCH_METRICS=" + json.dumps(payload))
        '''
    ).strip()

def stage_sandbox_files(
    work_dir: str,
    train_py: str,
    *,
    title: str,
    time_budget_sec: int,
    eval_tokens: int | None = None,
    config_overrides: dict[str, Any] | None = None,
) -> Path:
    """Copy support modules + edited train code into the sandbox work directory."""
    import autoresearch_types
    import prepare

    if eval_tokens is None:
        eval_tokens = 32 * prepare.MAX_SEQ_LEN
    root = Path(work_dir)
    root.mkdir(parents=True, exist_ok=True)
    (root / "train.py").write_text(train_py)
    if config_overrides:
        (root / "config_overrides.json").write_text(json.dumps(config_overrides))
    (root / "prepare.py").write_text(Path(prepare.__file__).read_text())
    (root / "autoresearch_types.py").write_text(Path(autoresearch_types.__file__).read_text())
    driver = write_driver_script(title, time_budget_sec, eval_tokens)
    driver_path = root / "driver.py"
    driver_path.write_text(driver)
    return driver_path

async def run_train_in_sandbox(
    work_dir: str,
    train_py: str,
    *,
    title: str,
    time_budget_sec: int,
    config_overrides: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Execute ``train.py`` via ``async with sb.on_device.session(backend='userns')``."""
    from union import sandbox as sb

    driver_path = stage_sandbox_files(
        work_dir,
        train_py,
        title=title,
        time_budget_sec=time_budget_sec,
        config_overrides=config_overrides,
    )
    timeout_s = max(time_budget_sec + 180, 300)

    try:
        async with sb.on_device.session(backend="userns", host_work_dir=work_dir) as sbx:
            proc = await sbx.run(
                f"python {driver_path}",
                stdout=True,
                stderr=True,
                network_mode="blocked",
                timeout_s=timeout_s,
            )
            stdout, stderr = await proc.communicate_text()
    except Exception as exc:
        err_text = str(exc)
        oom = is_oom(err_text, None)
        return {
            "success": False,
            "oom": oom,
            "title": title,
            "exit_code": None,
            "stdout_tail": "",
            "stderr": err_text,
            "error": (
                "Training run was OOM-killed; the platform will retry with more memory."
                if oom
                else f"Sandbox execution failed: {err_text}"
            ),
        }

    metrics = parse_metrics(stdout or "")
    oom = is_oom(stderr or "", proc.returncode, stdout=stdout or "")

    if metrics is not None and proc.returncode == 0:
        return {
            "success": True,
            "oom": False,
            **metrics,
            "exit_code": proc.returncode,
            "stderr_tail": (stderr or "")[-800:],
        }

    return {
        "success": False,
        "oom": oom,
        "title": title,
        "exit_code": proc.returncode,
        "stdout_tail": (stdout or "")[-1500:],
        "stderr": stderr or "",
        "error": (
            "Training run was OOM-killed; the platform will retry with more memory."
            if oom
            else f"Training failed (exit {proc.returncode}). See stderr for details."
        ),
    }

def _title_key(title: str) -> str:
    return str(title or "").strip().lower()

def _best_from_entries(entries: list[dict[str, Any]]) -> tuple[float | None, str | None]:
    best_val: float | None = None
    best_title: str | None = None
    for row in entries:
        val = row.get("val_bpb")
        if val is None:
            continue
        fval = float(val)
        if best_val is None or fval < best_val:
            best_val = fval
            best_title = str(row.get("title", ""))
    return best_val, best_title

def _latest_by_title(rows: list[dict[str, Any]], *, title_field: str = "title") -> dict[str, dict[str, Any]]:
    out: dict[str, dict[str, Any]] = {}
    for row in rows:
        key = _title_key(str(row.get(title_field, "")))
        if key:
            out[key] = row
    return out

def _outcome_label(
    *,
    val_bpb: float | None,
    success: bool | None,
    best_val: float | None,
) -> str:
    if success is False:
        return "failed"
    if val_bpb is None:
        # No recorded val_bpb and no recorded failure: the edit has not run yet.
        return "not_run"
    if best_val is None:
        return "ran"
    delta = float(val_bpb) - best_val
    if delta <= 0:
        return "new_best"
    return "regression"

def _vs_best_text(val_bpb: float | None, best_val: float | None) -> str:
    if val_bpb is None or best_val is None:
        return "—"
    delta = float(val_bpb) - best_val
    if abs(delta) < 1e-12:
        return "0 (ties best)"
    sign = "+" if delta > 0 else ""
    quality = "worse" if delta > 0 else "better"
    return f"{sign}{delta:.6g} ({quality})"

async def load_research_history(memory_key: str) -> dict[str, Any]:
    """Merge saved edits, run results, and outcomes for cross-session agent context."""
    memory = await MemoryStore.get_or_create.aio(key=memory_key)
    code_index: list[dict[str, Any]] = await memory.read_json.aio("memory/code_index.json", default=[])
    leaderboard: list[dict[str, Any]] = await memory.read_json.aio("memory/leaderboard.json", default=[])
    promising: list[dict[str, Any]] = await memory.read_json.aio("memory/promising_code.json", default=[])
    hypotheses: list[dict[str, Any]] = await memory.read_json.aio("memory/hypotheses.json", default=[])

    lb_by_title = _latest_by_title(leaderboard)
    prom_by_title = _latest_by_title(promising)
    hyp_by_title = _latest_by_title(hypotheses)

    best_val, best_title = _best_from_entries(leaderboard)
    if best_val is None:
        best_val, best_title = _best_from_entries(promising)

    trials: list[dict[str, Any]] = []
    seen: set[str] = set()

    for edit in code_index:
        title = str(edit.get("title", ""))
        key = _title_key(title)
        if not key:
            continue
        seen.add(key)
        lb = lb_by_title.get(key, {})
        prom = prom_by_title.get(key, {})
        hyp = hyp_by_title.get(key, {})

        val = lb.get("val_bpb")
        if val is None:
            val = prom.get("val_bpb")
        val_f = float(val) if val is not None else None

        success = lb.get("success")
        if success is None and val_f is not None:
            success = True
        if success is None and lb.get("error"):
            success = False

        beat_best = val_f is not None and best_val is not None and val_f <= best_val
        trials.append(
            {
                "title": title,
                "change_summary": edit.get("change_summary") or prom.get("change_summary") or "",
                "edited_at": edit.get("edited_at"),
                "val_bpb": val_f,
                "model_name": lb.get("model_name"),
                "success": success,
                "error": lb.get("error"),
                "hypothesis": hyp.get("hypothesis"),
                "expected_effect": hyp.get("expected_effect"),
                "beat_best": beat_best,
                "delta_vs_best": (val_f - best_val) if val_f is not None and best_val is not None else None,
                "vs_best": _vs_best_text(val_f, best_val),
                "outcome": _outcome_label(val_bpb=val_f, success=success, best_val=best_val),
                "kept": bool(prom.get("kept")),
            }
        )

    for key, lb in lb_by_title.items():
        if key in seen:
            continue
        val = lb.get("val_bpb")
        val_f = float(val) if val is not None else None
        success = lb.get("success")
        if success is None and val_f is not None:
            success = True
        trials.append(
            {
                "title": lb.get("title", key),
                "change_summary": "",
                "edited_at": None,
                "val_bpb": val_f,
                "model_name": lb.get("model_name"),
                "success": success,
                "error": lb.get("error"),
                "hypothesis": hyp_by_title.get(key, {}).get("hypothesis"),
                "expected_effect": hyp_by_title.get(key, {}).get("expected_effect"),
                "beat_best": val_f is not None and best_val is not None and val_f <= best_val,
                "delta_vs_best": (val_f - best_val) if val_f is not None and best_val is not None else None,
                "vs_best": _vs_best_text(val_f, best_val),
                "outcome": _outcome_label(val_bpb=val_f, success=success, best_val=best_val),
                "kept": bool(prom_by_title.get(key, {}).get("kept")),
            }
        )

    trials.sort(key=lambda t: (t.get("edited_at") or "", t.get("title", "")))

    return {
        "memory_key": memory_key,
        "best_val_bpb": best_val,
        "best_title": best_title,
        "trials": trials,
        "count_edits": len(code_index),
        "count_runs": sum(1 for t in trials if t.get("val_bpb") is not None or t.get("success") is False),
        "count_regressions": sum(1 for t in trials if t.get("outcome") == "regression"),
        "count_new_best": sum(1 for t in trials if t.get("outcome") == "new_best"),
    }

def format_research_history_for_directive(history: dict[str, Any], *, max_rows: int = 20) -> str:
    """Render prior edits/results as a compact block for the run directive."""
    trials: list[dict[str, Any]] = history.get("trials") or []
    if not trials:
        return ""

    best_val = history.get("best_val_bpb")
    best_title = history.get("best_title")
    header = "\n\n## Prior research (from memory — continue, do not repeat)\n"
    if best_val is not None:
        header += f"Current best: **val_bpb={best_val:.6g}** ({best_title}). Lower is better.\n"
    else:
        header += "No successful runs recorded yet.\n"

    header += (
        "Call ``get_code_edit_history()`` at the start to refresh this table. "
        "Use ``read_train_code`` on the best title to fork winners.\n\n"
    )

    lines = [
        "| Title | Change | val_bpb | vs best | Outcome |",
        "| --- | --- | --- | --- | --- |",
    ]
    for trial in trials[-max_rows:]:
        title = str(trial.get("title", ""))
        change = str(trial.get("change_summary", ""))[:72]
        val = trial.get("val_bpb")
        val_s = f"{float(val):.6g}" if val is not None else ("failed" if trial.get("success") is False else "—")
        lines.append(
            f"| {title} | {change} | {val_s} | {trial.get('vs_best', '—')} | {trial.get('outcome', '—')} |"
        )

    omitted = len(trials) - max_rows
    footer = ""
    if omitted > 0:
        footer = f"\n({omitted} older trial(s) omitted — use get_code_edit_history for the full list.)\n"

    return header + "\n".join(lines) + footer
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/parallelized_autoresearch/tools.py*

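`tools.right_size` is the prebuilt handler passed to `@tool(call_handler=...)`. It does not need a back-reference to the `Agent` instance: on each invocation, the harness passes `call_llm` and `tool_fn.model` into the handler. Because `flyte.map` calls `run_experiment.aio` directly rather than through the agent's tool registry, the module also rebinds the task with `call_llm` and `model` via `dataclasses.replace`.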
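The handler's OOM recovery can be pictured as a retry loop around the task call. The sketch below is a plain-asyncio illustration, not the `flyte.ai.agents` call-handler API: `retry_on_oom`, the `run(memory_gib)` callable, and the doubling policy are all assumptions. In the tutorial, the first memory request comes from the LLM right-sizing step; here it is simply the `initial_memory_gib` parameter.

```python
import asyncio
from typing import Any, Awaitable, Callable

# Hypothetical signature: the real flyte.ai.agents call_handler contract may differ.
RunFn = Callable[[int], Awaitable[dict[str, Any]]]


async def retry_on_oom(
    run: RunFn,
    *,
    initial_memory_gib: int,
    max_memory_gib: int = 64,
    max_retries: int = 3,
) -> dict[str, Any]:
    """Call *run* with a memory request, doubling it after each OOM result."""
    memory_gib = initial_memory_gib
    attempts = 0
    while True:
        result = await run(memory_gib)
        out_of_budget = attempts >= max_retries or memory_gib >= max_memory_gib
        # Stop on a non-OOM outcome, or when no further escalation is allowed.
        if not result.get("oom") or out_of_budget:
            return {**result, "memory_gib": memory_gib, "oom_retries": attempts}
        attempts += 1
        memory_gib = min(memory_gib * 2, max_memory_gib)


async def fake_train(memory_gib: int) -> dict[str, Any]:
    # Pretend the model needs at least 8 GiB; smaller requests are OOM-killed.
    if memory_gib < 8:
        return {"success": False, "oom": True}
    return {"success": True, "oom": False, "val_bpb": 1.1}


if __name__ == "__main__":
    print(asyncio.run(retry_on_oom(fake_train, initial_memory_gib=2)))
```

In this sketch, capping both the retry count and the memory ceiling keeps a genuinely oversized model from escalating indefinitely; it surfaces as an ordinary failed result that still reports how many OOM retries were spent.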

The experiment task stacks `@tool(call_handler=tools.right_size)` on top of `@experiment_env.task`. The task body loads the edited code, rejects duplicate configurations, runs sandbox training, and records the result; resource sizing and OOM recovery happen in the handler:

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.5.5",
#    "litellm",
#    "httpx",
#    "pydantic-monty",
#    "unionai-sandbox[flyte]",
#    "torch",
#    "numpy",
#    "pyarrow",
#    "requests",
#    "tiktoken",
#    "rustbpe",
# ]
# main = "parallelized_autoresearch"
# params = "--n-experiments 6 --batch-size 3 --num-shards 1"
# ///
"""Parallelized autoresearch agent — code-mode MLE agent with batched sandbox experiments."""

from __future__ import annotations

import dataclasses
from typing import Any

import flyte
import flyte.report
from flyte.ai.agents import Agent, MemoryStore, agent_progress_cb, tool

from autoresearch_types import AutoresearchOutput, DEFAULT_MAX_STEPS, DEFAULT_NUM_SHARDS, MAX_DEVICE_BATCH_SIZE, MAX_N_EMBD, MAX_N_HEAD, MAX_N_LAYER
from bundle import agent_env, build_bundle, experiment_env, materialize_cache, profile_bundle
import tools
import ui

MODEL = "claude-sonnet-4-6"

# {{docs-fragment run_experiment}}
@tool(call_handler=tools.right_size)
@experiment_env.task
async def run_experiment(
    title: str,
    time_budget_sec: int = 45,
    memory_key: str = tools.MEMORY_KEY_FANOUT,
) -> dict:
    """Train using agent-edited ``train.py`` with LLM right-sizing and OOM self-healing."""
    train_py = await tools.load_train_code(memory_key, title)
    config_overrides = await tools.load_config_overrides(memory_key, title)
    duplicate = await tools.check_duplicate_config(memory_key, title, train_py, config_overrides)
    if duplicate:
        result = {
            "success": False,
            "title": title,
            "error": (
                f"Duplicate config of '{duplicate['duplicate_of']}' "
                f"(signature {duplicate['config_signature']}); change train.py or overrides."
            ),
            "duplicate_of": duplicate["duplicate_of"],
        }
        await tools.record_experiment_result(
            memory_key,
            result,
            actor="parallelized-autoresearch",
        )
        return result
    bundle = await build_bundle()
    cache_dir = await materialize_cache(bundle)
    result = await tools.run_train_in_sandbox(
        cache_dir,
        train_py,
        title=title,
        time_budget_sec=time_budget_sec,
        config_overrides=config_overrides or None,
    )
    if result.get("success"):
        await tools.record_promising_run(memory_key, title, result)
        await tools.register_config_signature(
            memory_key,
            title,
            train_py,
            config_overrides,
            actor="parallelized-autoresearch",
        )
    await tools.record_experiment_result(
        memory_key,
        result,
        actor="parallelized-autoresearch",
    )
    return result

# ``flyte.map`` invokes ``run_experiment.aio`` directly (not through the agent
# registry), so bind the LLM callback and model here for ``call_handler`` right-sizing.
run_experiment = dataclasses.replace(
    run_experiment,
    call_llm=tools.call_llm,
    model=MODEL,
)
# {{/docs-fragment run_experiment}}

@tool
@agent_env.task
async def run_experiment_batch(
    titles: list[str],
    time_budget_sec: int = 45,
    memory_key: str = tools.MEMORY_KEY_FANOUT,
    concurrency: int = 4,
    batch_id: str = "",
) -> dict:
    """Run multiple ``run_experiment`` calls in parallel via ``flyte.map``.

    Prefer this over hand-rolling ``flyte_map`` when you already have a list of
    experiment titles with saved ``train.py`` edits.

    Args:
        titles: Experiment titles whose code was saved with ``edit_train_code_batch``.
        time_budget_sec: Wall-clock budget passed to each run.
        memory_key: Memory namespace from your directive.
        concurrency: Max parallel sandbox runs (default 4).
        batch_id: Optional label attached to the returned batch metadata.

    Returns:
        A dict with ``batch_size``, ``titles``, ``results``, and ``evaluation``
        (from :func:`evaluate_batch_results`).
    """
    group = batch_id or f"batch-{len(titles)}"
    payload = await tools.run_experiment_batch_impl(
        run_experiment,
        titles,
        time_budget_sec=time_budget_sec,
        memory_key=memory_key,
        concurrency=concurrency,
        group_name=group,
    )
    payload["evaluation"] = tools.evaluate_batch_results_impl(payload["results"], batch_id=batch_id)
    await tools.persist_run_results_to_leaderboard(memory_key, payload["results"])
    return payload

INSTRUCTIONS = f"""\
You are a senior ML-engineer agent running karpathy/autoresearch-style research by
**editing train.py** and **batching parallel experiments**. Your goal: MINIMIZE
val_bpb (LOWER is better).

You operate in CODE MODE. Each turn, write ONE ```python``` block that calls the
available functions, OR reply in plain text when finished. The last expression in
your code block is returned as the observation.

Core tools:
- get_code_edit_history — **call first on resumed sessions**: prior edits, val_bpb, vs-best deltas
- get_baseline_train_code, edit_train_code_batch, read_train_code, get_promising_code
- inspect_dataset, search_arxiv
- get_leaderboard, compare_experiments

Saving edits (required for visible diffs and distinct runs):
- **Batch 1 only:** you may use ``config_overrides`` for a quick architecture/LR sweep via
  ``edit_train_code_batch(edits=[{{"title": "...", "config_overrides": {{"n_layer": 6}}, "change_summary": "..."}}])``.
- **Batch 2 and later:** every edit must include a **substantive ``train_py`` change**
  (learning-rate schedule, optimizer/weight_decay, grad clipping, warmup, etc.).
  ``config_overrides`` alone is **rejected** after the first batch — fork with
  ``parent_title`` and edit the training loop in ``train_py``.
- ``config_overrides`` fields: ``n_layer``, ``n_head``, ``n_embd``, ``dropout``,
  ``device_batch_size``, ``learning_rate``, ``time_budget_sec``, ``max_steps``.
- To fork a winner: set ``parent_title`` to the best title, then edit ``train_py``.
- Do **not** save baseline ``train.py`` without overrides — the platform rejects identical edits.
- Duplicate configs (same effective train.py + overrides) are rejected at run time.

Training budget (fair comparison across architectures):
- Default **max_steps={DEFAULT_MAX_STEPS}** with **time_budget_sec=45** as a safety cap.
  All models train for the same step count unless they hit the wall-clock limit.
- Check ``steps`` in batch results — if a run stopped early on time, the model may be too large.

Batch / fan-out tools:
- record_batch_plan(batch_id, experiments) — persist a multi-experiment plan
- get_batch_plan(batch_id) — reload a plan
- record_batch_hypotheses(experiments) — write hypotheses for every title in a batch
- edit_train_code_batch(edits) — save all ``train.py`` edits in one memory transaction
- run_experiment_batch(titles, concurrency=...) — parallel sandbox runs (LLM right-sized; OOM-healed)
- evaluate_batch_results(results, batch_id=...) — rank successes vs failures

Typical batch loop (aim for **≤8 code turns** before your plain-text summary):
0. If prior research exists in your directive, ``get_code_edit_history()`` then
   ``read_train_code(best_title)`` before planning new batches.
1. Turn 1: ``get_baseline_train_code()`` + ``inspect_dataset()``.
2. Turn 2: ``record_batch_plan`` then ``edit_train_code_batch(edits=[...])`` for the whole batch.
3. Turn 3: ``record_batch_hypotheses`` + ``run_experiment_batch(titles, concurrency=...)``.
4. Turn 4+: fork winners into the next batch with **train.py** edits, or reply in plain text when done.

Batch diversity (required):
- Every title in a batch must test a **distinct hypothesis** — no duplicate configs or renames.
- **Spread axes across the batch**: e.g. one edit tweaks depth/width, another changes the
  **training loop** (cosine LR, AdamW betas, weight decay), another regularization or batch size.
- Avoid LR micro-sweeps (±30% of the current best LR) after batch 1 — those rarely beat a plateau.
- Vary **one or two knobs per edit**; state the change in ``change_summary`` and
  ``record_batch_hypotheses``.
- Use ``evaluate_batch_results`` to see **which axis** helped, then explore under-tested axes.

Plateau rule (required):
- If **3 consecutive batches** fail to beat the global best val_bpb by more than **0.01**,
  stop hyperparameter micro-sweeps. Switch to **training-loop code edits** in ``train.py``
  (scheduler, optimizer, regularization, data/loss changes).

Rules:
- Use ``edit_train_code_batch`` for all code saves (including a single title: ``edits=[{{...}}]``).
- Every edit must keep ``run_training(config: ExperimentConfig) -> ExperimentResult``.
- Do NOT size compute — each run is LLM right-sized and retried automatically on OOM.
- Workshop limits: n_layer<={MAX_N_LAYER}, n_embd<={MAX_N_EMBD}, n_head<={MAX_N_HEAD},
  device_batch_size<={MAX_DEVICE_BATCH_SIZE}, seq_len=512.
- Monty sandbox: no imports, no dict mutation, no augmented assignment (`+=`).
- **Always finish with plain text (no code block)** once you have results to report.
"""

DEFAULT_MAX_TURNS = 50

def build_fanout_agent(*, max_turns: int = DEFAULT_MAX_TURNS) -> Agent:
    """Construct the fan-out agent (``code_mode=True``) with a configurable turn budget."""
    return Agent(
        name="parallelized-autoresearch",
        instructions=INSTRUCTIONS,
        model=MODEL,
        tools=[
            tools.search_arxiv,
            tools.inspect_dataset,
            tools.get_baseline_train_code,
            tools.get_code_edit_history,
            tools.edit_train_code_batch,
            tools.read_train_code,
            tools.get_promising_code,
            tools.get_leaderboard,
            tools.compare_experiments,
            tools.record_batch_plan,
            tools.get_batch_plan,
            tools.record_batch_hypotheses,
            run_experiment_batch,
            tools.evaluate_batch_results,
        ],
        max_turns=max_turns,
        call_llm=tools.call_llm,
        code_mode=True,
    )

# {{docs-fragment agent}}
@agent_env.task(report=True)
async def parallelized_autoresearch(
    n_experiments: int = 6,
    num_shards: int = DEFAULT_NUM_SHARDS,
    memory_key: str = tools.MEMORY_KEY_FANOUT,
    batch_size: int = 3,
    max_turns: int = DEFAULT_MAX_TURNS,
) -> AutoresearchOutput:
    """Drive the fan-out code-edit MLE agent with sandbox batch execution."""
    bundle = await build_bundle(num_shards=num_shards)
    profile = await profile_bundle(bundle)

    memory = await MemoryStore.get_or_create.aio(key=memory_key)
    persisted = await memory.read_json.aio("memory/leaderboard.json", default=[])
    promising = await memory.read_json.aio("memory/promising_code.json", default=[])
    history = await tools.load_research_history(memory_key)
    flyte.logger.info(
        "Fan-out agent restored %d messages, %d experiments, %d promising edits, best val_bpb=%s.",
        len(memory.messages),
        len(persisted),
        len(promising),
        history.get("best_val_bpb"),
    )

    events: list[dict[str, Any]] = []

    async def on_event(ev) -> None:
        events.append({"type": ev.type, "data": ev.data})
        if ev.type in ("tool_start", "tool_end", "tool_error", "turn_start", "agent_end"):
            tab = flyte.report.get_tab("Activity")
            tab.replace(ui.render_activity_log(events))
            await flyte.report.flush.aio()
        if ev.type == "tool_end" and ev.data.get("tool") in (
            "edit_train_code_batch",
            "<sandbox>",
        ):
            edits = await tools.load_saved_code_edits(memory_key)
            if edits:
                flyte.report.get_tab("Code edits").replace(ui.render_code_edits_panel(edits))
                await flyte.report.flush.aio()

    directive_text = ui.directive_code_edit_fanout(
        n_experiments,
        profile,
        memory_key,
        batch_size=batch_size,
        history=history,
    )

    token = agent_progress_cb.set(on_event)
    run_agent = build_fanout_agent(max_turns=max_turns)
    try:
        result = await run_agent.run.aio(directive_text, memory=memory)
    finally:
        agent_progress_cb.reset(token)

    leaderboard, best = ui.parse_leaderboard(
        memory.messages,
        promising_fallback=promising,
    )
    leaderboard_dicts = [dataclasses.asdict(e) for e in leaderboard]
    code_edits = await tools.load_saved_code_edits(memory_key)

    tab_lb = flyte.report.get_tab("Leaderboard")
    tab_lb.replace(ui.render_leaderboard(leaderboard, best))

    flyte.report.get_tab("Code edits").replace(
        ui.render_code_edits_panel(code_edits, best_title=best.title if best else None)
    )

    await memory.write_json.aio(
        "memory/leaderboard.json",
        leaderboard_dicts,
        actor="parallelized-autoresearch",
        reason=f"leaderboard after {len(leaderboard)} experiments",
    )
    await memory.save.aio()
    audit = await memory.audit_tail(20)
    hypotheses = await memory.read_json.aio("memory/hypotheses.json", default=[])
    promising = await memory.read_json.aio("memory/promising_code.json", default=[])

    tab_mem = flyte.report.get_tab("Memory")
    tab_mem.replace(
        ui.render_memory_panel(
            memory_key,
            len(memory.messages),
            leaderboard_dicts,
            audit,
            hypotheses,
            persisted_promising=promising,
            code_edits=code_edits,
        )
    )

    summary_body = result.summary or result.error or ""
    if result.error and leaderboard:
        best_line = f" Best val_bpb so far: {best.val_bpb} ({best.title})." if best and best.val_bpb else ""
        summary_body = f"{result.error}{best_line}"

    await flyte.report.replace.aio(
        ui.render_summary(
            directive_text,
            leaderboard,
            best,
            summary_body,
            code_edits=code_edits,
        )
    )
    await flyte.report.flush.aio()

    return AutoresearchOutput(
        directive=directive_text,
        dataset_profile=profile,
        best=best,
        leaderboard=leaderboard,
        summary=summary_body,
        memory_key=memory_key,
        total_experiments=len(leaderboard),
    )

# {{/docs-fragment agent}}

# {{docs-fragment main}}
if __name__ == "__main__":
    import argparse
    import asyncio
    import os

    parser = argparse.ArgumentParser(description="Parallelized autoresearch agent (CODE MODE)")
    parser.add_argument("--n-experiments", type=int, default=6)
    parser.add_argument("--batch-size", type=int, default=3)
    parser.add_argument("--max-turns", type=int, default=DEFAULT_MAX_TURNS)
    parser.add_argument("--num-shards", type=int, default=DEFAULT_NUM_SHARDS)
    parser.add_argument("--memory-key", default=tools.MEMORY_KEY_FANOUT)
    parser.add_argument(
        "--config",
        default=os.environ.get("FLYTE_CONFIG", os.path.expanduser("~/.flyte/config.yaml")),
    )
    args = parser.parse_args()

    flyte.init_from_config(args.config, image_builder="remote")

    async def main() -> None:
        run = await flyte.with_runcontext(copy_style="all").run.aio(
            parallelized_autoresearch,
            n_experiments=args.n_experiments,
            num_shards=args.num_shards,
            memory_key=args.memory_key,
            batch_size=args.batch_size,
            max_turns=args.max_turns,
        )
        print(f"View run at: {run.url}")

    asyncio.run(main())
# {{/docs-fragment main}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/parallelized_autoresearch/parallelized_autoresearch.py*

Batch fan-out calls `flyte.map.aio(run_experiment, ...)` from `run_experiment_batch`. That path invokes `run_experiment.aio()` directly — **not** through the agent registry — so the example binds `call_llm` and `model` on the tool after construction (see the `dataclasses.replace` block above). With Flyte SDK ≥ 2.5.5, `AgentTool.aio` routes through `call_handler`, so every mapped experiment gets LLM right-sizing even when the agent only exposes `run_experiment_batch` in code mode.
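You can see why the rebinding matters without involving Flyte at all. The sketch below is a minimal, hypothetical model and not Flyte's `AgentTool`. `ToolSketch` is a frozen dataclass whose `aio` method routes through `call_handler` only when `call_llm` and `model` are both set. `right_size_sketch` is a toy handler standing in for `tools.right_size`. `dataclasses.replace` returns a new instance with those fields bound, which is the same move the example makes on `run_experiment`.

```python
import asyncio
import dataclasses
from typing import Any, Awaitable, Callable, Optional


# Hypothetical stand-in for an agent tool wrapper; NOT Flyte's AgentTool.
@dataclasses.dataclass(frozen=True)
class ToolSketch:
    fn: Callable[..., Awaitable[Any]]
    call_handler: Optional[Callable[..., Awaitable[Any]]] = None
    call_llm: Optional[Callable[..., Awaitable[str]]] = None
    model: Optional[str] = None

    async def aio(self, *args: Any, **kwargs: Any) -> Any:
        # Route through the handler only when everything it needs is bound.
        if self.call_handler and self.call_llm and self.model:
            return await self.call_handler(self, *args, **kwargs)
        return await self.fn(*args, **kwargs)


async def train(title: str) -> dict:
    return {"title": title, "success": True}


async def right_size_sketch(tool: ToolSketch, title: str) -> dict:
    # A real handler would ask the LLM for a resource plan; here we just attach its reply.
    plan = await tool.call_llm(f"size {title}", model=tool.model)
    result = await tool.fn(title)
    return {**result, "resources": plan}


async def fake_llm(prompt: str, model: str) -> str:
    return f"{model}:cpu=2"


unbound = ToolSketch(fn=train, call_handler=right_size_sketch)
bound = dataclasses.replace(unbound, call_llm=fake_llm, model="claude-sonnet-4-6")


async def main() -> tuple[dict, dict]:
    # Direct .aio() calls, like the ones flyte.map makes, bypass any agent registry.
    return await unbound.aio("exp-a"), await bound.aio("exp-b")


plain, sized = asyncio.run(main())
print(plain)  # handler skipped: nothing bound
print(sized)  # handler ran with the bound model
```

Calling `unbound.aio` skips the handler entirely. That is the failure mode the rebinding avoids when `flyte.map` calls a tool directly.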

## The fan-out agent task

The driver task `parallelized_autoresearch` restores prior memory (default key `parallelized-autoresearch`), streams Activity / Leaderboard / Code edits / Memory report tabs, and runs the code-mode agent loop. The agent tool registry is trimmed to the batch workflow — `run_experiment` is internal to `run_experiment_batch`, not a sandbox function the LLM calls directly.
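The driver wires its report callback with `token = agent_progress_cb.set(on_event)` and restores it with `agent_progress_cb.reset(token)` in a `finally` block. This is the standard `contextvars.ContextVar` idiom. Assuming `agent_progress_cb` behaves like a `ContextVar`, the pattern looks like this in isolation. `progress_cb`, `emit`, `fake_agent_loop`, and the `llm_delta` event type are invented for illustration.

```python
import asyncio
import contextvars
from dataclasses import dataclass, field


@dataclass
class Event:
    type: str
    data: dict = field(default_factory=dict)


# Hypothetical stand-in for flyte.ai.agents.agent_progress_cb.
progress_cb: contextvars.ContextVar = contextvars.ContextVar("progress_cb", default=None)

# Event types that trigger a report redraw, copied from the driver's filter.
REDRAW_ON = {"tool_start", "tool_end", "tool_error", "turn_start", "agent_end"}


async def emit(ev: Event) -> None:
    # Library side: deliver the event only if a listener is installed in this context.
    cb = progress_cb.get()
    if cb is not None:
        await cb(ev)


async def fake_agent_loop() -> None:
    # "llm_delta" is an invented event type that the redraw filter should ignore.
    for ev in (Event("turn_start"), Event("llm_delta"), Event("tool_end", {"tool": "<sandbox>"})):
        await emit(ev)


async def main() -> tuple[list[str], int]:
    events: list[str] = []
    redraws = 0

    async def on_event(ev: Event) -> None:
        nonlocal redraws
        events.append(ev.type)
        if ev.type in REDRAW_ON:
            redraws += 1  # the real driver re-renders the Activity tab here

    token = progress_cb.set(on_event)
    try:
        await fake_agent_loop()
    finally:
        progress_cb.reset(token)  # restore the previous value even if the loop raises
    return events, redraws


events, redraws = asyncio.run(main())
print(events, redraws)
```

Every event is recorded, but only the filtered types trigger a redraw. This keeps report flushes limited to meaningful agent milestones.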

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.5.5",
#    "litellm",
#    "httpx",
#    "pydantic-monty",
#    "unionai-sandbox[flyte]",
#    "torch",
#    "numpy",
#    "pyarrow",
#    "requests",
#    "tiktoken",
#    "rustbpe",
# ]
# main = "parallelized_autoresearch"
# params = "--n-experiments 6 --batch-size 3 --num-shards 1"
# ///
"""Parallelized autoresearch agent — code-mode MLE agent with batched sandbox experiments."""

from __future__ import annotations

import dataclasses
from typing import Any

import flyte
import flyte.report
from flyte.ai.agents import Agent, MemoryStore, agent_progress_cb, tool

from autoresearch_types import AutoresearchOutput, DEFAULT_MAX_STEPS, DEFAULT_NUM_SHARDS, MAX_DEVICE_BATCH_SIZE, MAX_N_EMBD, MAX_N_HEAD, MAX_N_LAYER
from bundle import agent_env, build_bundle, experiment_env, materialize_cache, profile_bundle
import tools
import ui

MODEL = "claude-sonnet-4-6"

# {{docs-fragment run_experiment}}
@tool(call_handler=tools.right_size)
@experiment_env.task
async def run_experiment(
    title: str,
    time_budget_sec: int = 45,
    memory_key: str = tools.MEMORY_KEY_FANOUT,
) -> dict:
    """Train using agent-edited ``train.py`` with LLM right-sizing and OOM self-healing."""
    train_py = await tools.load_train_code(memory_key, title)
    config_overrides = await tools.load_config_overrides(memory_key, title)
    duplicate = await tools.check_duplicate_config(memory_key, title, train_py, config_overrides)
    if duplicate:
        result = {
            "success": False,
            "title": title,
            "error": (
                f"Duplicate config of '{duplicate['duplicate_of']}' "
                f"(signature {duplicate['config_signature']}); change train.py or overrides."
            ),
            "duplicate_of": duplicate["duplicate_of"],
        }
        await tools.record_experiment_result(
            memory_key,
            result,
            actor="parallelized-autoresearch",
        )
        return result
    bundle = await build_bundle()
    cache_dir = await materialize_cache(bundle)
    result = await tools.run_train_in_sandbox(
        cache_dir,
        train_py,
        title=title,
        time_budget_sec=time_budget_sec,
        config_overrides=config_overrides or None,
    )
    if result.get("success"):
        await tools.record_promising_run(memory_key, title, result)
        await tools.register_config_signature(
            memory_key,
            title,
            train_py,
            config_overrides,
            actor="parallelized-autoresearch",
        )
    await tools.record_experiment_result(
        memory_key,
        result,
        actor="parallelized-autoresearch",
    )
    return result

# ``flyte.map`` invokes ``run_experiment.aio`` directly (not through the agent
# registry), so bind the LLM callback and model here for ``call_handler`` right-sizing.
run_experiment = dataclasses.replace(
    run_experiment,
    call_llm=tools.call_llm,
    model=MODEL,
)
# {{/docs-fragment run_experiment}}

@tool
@agent_env.task
async def run_experiment_batch(
    titles: list[str],
    time_budget_sec: int = 45,
    memory_key: str = tools.MEMORY_KEY_FANOUT,
    concurrency: int = 4,
    batch_id: str = "",
) -> dict:
    """Run multiple ``run_experiment`` calls in parallel via ``flyte.map``.

    Prefer this over hand-rolling ``flyte_map`` when you already have a list of
    experiment titles with saved ``train.py`` edits.

    Args:
        titles: Experiment titles whose code was saved with ``edit_train_code_batch``.
        time_budget_sec: Wall-clock budget passed to each run.
        memory_key: Memory namespace from your directive.
        concurrency: Max parallel sandbox runs (default 4).
        batch_id: Optional label attached to the returned batch metadata.

    Returns:
        A dict with ``batch_size``, ``titles``, ``results``, and ``evaluation``
        (from :func:`evaluate_batch_results`).
    """
    group = batch_id or f"batch-{len(titles)}"
    payload = await tools.run_experiment_batch_impl(
        run_experiment,
        titles,
        time_budget_sec=time_budget_sec,
        memory_key=memory_key,
        concurrency=concurrency,
        group_name=group,
    )
    payload["evaluation"] = tools.evaluate_batch_results_impl(payload["results"], batch_id=batch_id)
    await tools.persist_run_results_to_leaderboard(memory_key, payload["results"])
    return payload

INSTRUCTIONS = f"""\
You are a senior ML-engineer agent running karpathy/autoresearch-style research by
**editing train.py** and **batching parallel experiments**. Your goal: MINIMIZE
val_bpb (LOWER is better).

You operate in CODE MODE. Each turn, write ONE ```python``` block that calls the
available functions, OR reply in plain text when finished. The last expression in
your code block is returned as the observation.

Core tools:
- get_code_edit_history — **call first on resumed sessions**: prior edits, val_bpb, vs-best deltas
- get_baseline_train_code, edit_train_code_batch, read_train_code, get_promising_code
- inspect_dataset, search_arxiv
- get_leaderboard, compare_experiments

Saving edits (required for visible diffs and distinct runs):
- **Batch 1 only:** you may use ``config_overrides`` for a quick architecture/LR sweep via
  ``edit_train_code_batch(edits=[{{"title": "...", "config_overrides": {{"n_layer": 6}}, "change_summary": "..."}}])``.
- **Batch 2 and later:** every edit must include a **substantive ``train_py`` change**
  (learning-rate schedule, optimizer/weight_decay, grad clipping, warmup, etc.).
  ``config_overrides`` alone is **rejected** after the first batch — fork with
  ``parent_title`` and edit the training loop in ``train_py``.
- ``config_overrides`` fields: ``n_layer``, ``n_head``, ``n_embd``, ``dropout``,
  ``device_batch_size``, ``learning_rate``, ``time_budget_sec``, ``max_steps``.
- To fork a winner: set ``parent_title`` to the best title, then edit ``train_py``.
- Do **not** save baseline ``train.py`` without overrides — the platform rejects identical edits.
- Duplicate configs (same effective train.py + overrides) are rejected at run time.

Training budget (fair comparison across architectures):
- Default **max_steps={DEFAULT_MAX_STEPS}** with **time_budget_sec=45** as a safety cap.
  All models train for the same step count unless they hit the wall-clock limit.
- Check ``steps`` in batch results — if a run stopped early on time, the model may be too large.

Batch / fan-out tools:
- record_batch_plan(batch_id, experiments) — persist a multi-experiment plan
- get_batch_plan(batch_id) — reload a plan
- record_batch_hypotheses(experiments) — write hypotheses for every title in a batch
- edit_train_code_batch(edits) — save all ``train.py`` edits in one memory transaction
- run_experiment_batch(titles, concurrency=...) — parallel sandbox runs (LLM right-sized; OOM-healed)
- evaluate_batch_results(results, batch_id=...) — rank successes vs failures

Typical batch loop (aim for **≤8 code turns** before your plain-text summary):
0. If prior research exists in your directive, ``get_code_edit_history()`` then
   ``read_train_code(best_title)`` before planning new batches.
1. Turn 1: ``get_baseline_train_code()`` + ``inspect_dataset()``.
2. Turn 2: ``record_batch_plan`` then ``edit_train_code_batch(edits=[...])`` for the whole batch.
3. Turn 3: ``record_batch_hypotheses`` + ``run_experiment_batch(titles, concurrency=...)``.
4. Turn 4+: fork winners into the next batch with **train.py** edits, or reply in plain text when done.

Batch diversity (required):
- Every title in a batch must test a **distinct hypothesis** — no duplicate configs or renames.
- **Spread axes across the batch**: e.g. one edit tweaks depth/width, another changes the
  **training loop** (cosine LR, AdamW betas, weight decay), another regularization or batch size.
- Avoid LR micro-sweeps (±30% of the current best LR) after batch 1 — those rarely beat a plateau.
- Vary **one or two knobs per edit**; state the change in ``change_summary`` and
  ``record_batch_hypotheses``.
- Use ``evaluate_batch_results`` to see **which axis** helped, then explore under-tested axes.

Plateau rule (required):
- If **3 consecutive batches** fail to beat the global best val_bpb by more than **0.01**,
  stop hyperparameter micro-sweeps. Switch to **training-loop code edits** in ``train.py``
  (scheduler, optimizer, regularization, data/loss changes).

Rules:
- Use ``edit_train_code_batch`` for all code saves (including a single title: ``edits=[{{...}}]``).
- Every edit must keep ``run_training(config: ExperimentConfig) -> ExperimentResult``.
- Do NOT size compute — each run is LLM right-sized and retried automatically on OOM.
- Workshop limits: n_layer<={MAX_N_LAYER}, n_embd<={MAX_N_EMBD}, n_head<={MAX_N_HEAD},
  device_batch_size<={MAX_DEVICE_BATCH_SIZE}, seq_len=512.
- Monty sandbox: no imports, no dict mutation, no augmented assignment (`+=`).
- **Always finish with plain text (no code block)** once you have results to report.
"""

DEFAULT_MAX_TURNS = 50

def build_fanout_agent(*, max_turns: int = DEFAULT_MAX_TURNS) -> Agent:
    """Construct the fan-out agent (``code_mode=True``) with a configurable turn budget."""
    return Agent(
        name="parallelized-autoresearch",
        instructions=INSTRUCTIONS,
        model=MODEL,
        tools=[
            tools.search_arxiv,
            tools.inspect_dataset,
            tools.get_baseline_train_code,
            tools.get_code_edit_history,
            tools.edit_train_code_batch,
            tools.read_train_code,
            tools.get_promising_code,
            tools.get_leaderboard,
            tools.compare_experiments,
            tools.record_batch_plan,
            tools.get_batch_plan,
            tools.record_batch_hypotheses,
            run_experiment_batch,
            tools.evaluate_batch_results,
        ],
        max_turns=max_turns,
        call_llm=tools.call_llm,
        code_mode=True,
    )

# {{docs-fragment agent}}
@agent_env.task(report=True)
async def parallelized_autoresearch(
    n_experiments: int = 6,
    num_shards: int = DEFAULT_NUM_SHARDS,
    memory_key: str = tools.MEMORY_KEY_FANOUT,
    batch_size: int = 3,
    max_turns: int = DEFAULT_MAX_TURNS,
) -> AutoresearchOutput:
    """Drive the fan-out code-edit MLE agent with sandbox batch execution."""
    bundle = await build_bundle(num_shards=num_shards)
    profile = await profile_bundle(bundle)

    memory = await MemoryStore.get_or_create.aio(key=memory_key)
    persisted = await memory.read_json.aio("memory/leaderboard.json", default=[])
    promising = await memory.read_json.aio("memory/promising_code.json", default=[])
    history = await tools.load_research_history(memory_key)
    flyte.logger.info(
        "Fan-out agent restored %d messages, %d experiments, %d promising edits, best val_bpb=%s.",
        len(memory.messages),
        len(persisted),
        len(promising),
        history.get("best_val_bpb"),
    )

    events: list[dict[str, Any]] = []

    async def on_event(ev) -> None:
        events.append({"type": ev.type, "data": ev.data})
        if ev.type in ("tool_start", "tool_end", "tool_error", "turn_start", "agent_end"):
            tab = flyte.report.get_tab("Activity")
            tab.replace(ui.render_activity_log(events))
            await flyte.report.flush.aio()
        if ev.type == "tool_end" and ev.data.get("tool") in (
            "edit_train_code_batch",
            "<sandbox>",
        ):
            edits = await tools.load_saved_code_edits(memory_key)
            if edits:
                flyte.report.get_tab("Code edits").replace(ui.render_code_edits_panel(edits))
                await flyte.report.flush.aio()

    directive_text = ui.directive_code_edit_fanout(
        n_experiments,
        profile,
        memory_key,
        batch_size=batch_size,
        history=history,
    )

    token = agent_progress_cb.set(on_event)
    run_agent = build_fanout_agent(max_turns=max_turns)
    try:
        result = await run_agent.run.aio(directive_text, memory=memory)
    finally:
        agent_progress_cb.reset(token)

    leaderboard, best = ui.parse_leaderboard(
        memory.messages,
        promising_fallback=promising,
    )
    leaderboard_dicts = [dataclasses.asdict(e) for e in leaderboard]
    code_edits = await tools.load_saved_code_edits(memory_key)

    tab_lb = flyte.report.get_tab("Leaderboard")
    tab_lb.replace(ui.render_leaderboard(leaderboard, best))

    flyte.report.get_tab("Code edits").replace(
        ui.render_code_edits_panel(code_edits, best_title=best.title if best else None)
    )

    await memory.write_json.aio(
        "memory/leaderboard.json",
        leaderboard_dicts,
        actor="parallelized-autoresearch",
        reason=f"leaderboard after {len(leaderboard)} experiments",
    )
    await memory.save.aio()
    audit = await memory.audit_tail(20)
    hypotheses = await memory.read_json.aio("memory/hypotheses.json", default=[])
    promising = await memory.read_json.aio("memory/promising_code.json", default=[])

    tab_mem = flyte.report.get_tab("Memory")
    tab_mem.replace(
        ui.render_memory_panel(
            memory_key,
            len(memory.messages),
            leaderboard_dicts,
            audit,
            hypotheses,
            persisted_promising=promising,
            code_edits=code_edits,
        )
    )

    summary_body = result.summary or result.error or ""
    if result.error and leaderboard:
        best_line = f" Best val_bpb so far: {best.val_bpb} ({best.title})." if best and best.val_bpb else ""
        summary_body = f"{result.error}{best_line}"

    await flyte.report.replace.aio(
        ui.render_summary(
            directive_text,
            leaderboard,
            best,
            summary_body,
            code_edits=code_edits,
        )
    )
    await flyte.report.flush.aio()

    return AutoresearchOutput(
        directive=directive_text,
        dataset_profile=profile,
        best=best,
        leaderboard=leaderboard,
        summary=summary_body,
        memory_key=memory_key,
        total_experiments=len(leaderboard),
    )

# {{/docs-fragment agent}}

# {{docs-fragment main}}
if __name__ == "__main__":
    import argparse
    import asyncio
    import os

    parser = argparse.ArgumentParser(description="Parallelized autoresearch agent (CODE MODE)")
    parser.add_argument("--n-experiments", type=int, default=6)
    parser.add_argument("--batch-size", type=int, default=3)
    parser.add_argument("--max-turns", type=int, default=DEFAULT_MAX_TURNS)
    parser.add_argument("--num-shards", type=int, default=DEFAULT_NUM_SHARDS)
    parser.add_argument("--memory-key", default=tools.MEMORY_KEY_FANOUT)
    parser.add_argument(
        "--config",
        default=os.environ.get("FLYTE_CONFIG", os.path.expanduser("~/.flyte/config.yaml")),
    )
    args = parser.parse_args()

    flyte.init_from_config(args.config, image_builder="remote")

    async def main() -> None:
        run = await flyte.with_runcontext(copy_style="all").run.aio(
            parallelized_autoresearch,
            n_experiments=args.n_experiments,
            num_shards=args.num_shards,
            memory_key=args.memory_key,
            batch_size=args.batch_size,
            max_turns=args.max_turns,
        )
        print(f"View run at: {run.url}")

    asyncio.run(main())
# {{/docs-fragment main}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/parallelized_autoresearch/parallelized_autoresearch.py*
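The plateau rule in `INSTRUCTIONS` is a prompt-level instruction to the LLM, and the example does not enforce it in code. As a hedged sketch, the same rule could be expressed as a small helper over each batch's best val_bpb. `plateaued` is hypothetical. So is its reading of the rule: a gain of 0.01 or less still updates the running best but counts as a miss.

```python
from typing import Optional, Sequence


def plateaued(batch_bests: Sequence[float], window: int = 3, min_gain: float = 0.01) -> bool:
    """True if the last `window` batches each failed to beat the running best
    val_bpb (lower is better) by more than `min_gain`."""
    best: Optional[float] = None
    misses = 0  # consecutive batches without a real improvement
    for val in batch_bests:
        if best is None:
            best = val  # the first batch only sets the baseline
            continue
        if best - val > min_gain:
            misses = 0  # a real improvement resets the streak
        else:
            misses += 1
        best = min(best, val)  # assumption: small gains still move the global best
    return misses >= window


# Best val_bpb per batch; invented numbers for illustration.
print(plateaued([1.20, 1.10, 1.095, 1.092, 1.090]))  # → True (three small gains in a row)
print(plateaued([1.20, 1.10, 1.095, 1.08]))          # → False (last batch gains 0.015)
```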

## Run the agent

### Create secrets

Register an Anthropic API key for agent LLM calls and for per-experiment resource sizing inside `call_handler`:

```
flyte create secret internal-anthropic-api-key <YOUR_ANTHROPIC_API_KEY>
```

### Run remotely

From the [example directory](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/parallelized_autoresearch):

```
cd v2/tutorials/parallelized_autoresearch
uv run --script parallelized_autoresearch.py --n-experiments 6 --batch-size 3 --num-shards 1
```

Reusing a `--memory-key` resumes that research session. The default key is `parallelized-autoresearch`, so repeated runs without the flag share memory. To start with empty memory, pass a unique key such as `parallelized-autoresearch-20260622-215057`. Code mode needs more turns than JSON tool mode, so increase `--max-turns` (default 50) for larger sweeps.
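Any unique string works as a fresh key. The minimal sketch below produces keys in the same shape as the example key. It assumes the suffix is a `YYYYMMDD-HHMMSS` timestamp, a format inferred from that example rather than required by the script. `fresh_memory_key` is a hypothetical helper.

```python
from datetime import datetime
from typing import Optional

DEFAULT_KEY = "parallelized-autoresearch"


def fresh_memory_key(prefix: str = DEFAULT_KEY, now: Optional[datetime] = None) -> str:
    """Build a unique memory key such as parallelized-autoresearch-20260622-215057."""
    now = now or datetime.now()
    return f"{prefix}-{now:%Y%m%d-%H%M%S}"


key = fresh_memory_key()
print(key)
# Then pass it to the script, e.g.:
#   uv run --script parallelized_autoresearch.py --memory-key <key>
```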

Or invoke the agent task directly with `flyte run` (snake_case task inputs):

```
flyte run parallelized_autoresearch.py parallelized_autoresearch \
  --n_experiments 6 --batch_size 3 --num_shards 1 --max_turns 12 \
  --memory_key parallelized-autoresearch
```

> [!NOTE]
> The first run downloads climbmix data shards and trains a BPE tokenizer. Subsequent runs reuse cached bundle tasks. Requires **Flyte SDK ≥ 2.5.5** for `call_handler` support in code mode and on `AgentTool.aio` (used by `flyte.map` fan-out).

See also the single-task [Autoresearch agent](../autoresearch/_index) tutorial for the Claude Code + pull-request workflow.

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/agents/autosec-research-agent ===

# AutoSec researcher agent

> [!NOTE]
> Code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/autosec_research_agent).

This tutorial demonstrates an autonomous security-research agent on Flyte. The pipeline fans out across bundled C source files (each with a planted memory-corruption bug), runs static analysis, uses a `flyte.ai.agents.Agent` to hypothesize vulnerabilities, builds proof-of-concept payloads, and validates exploits inside an on-device [unionai-sandbox](https://www.union.ai/docs/v2/union/user-guide/sandboxing/_index) user-namespace session.

Flyte provides:

- **Parallel fan-out** across every target file with `asyncio.gather`
- **Self-healing tasks**: LLM timeouts and malformed JSON replies fail the task and trigger a bounded retry. An out-of-memory error during static analysis re-runs the scan with more memory and a narrower, per-file scope.
- **Sandbox isolation**: PoC compilation and execution never run on the orchestration node
- **Live HTML reports** with per-target detail tabs in the Flyte UI
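The following is a local, Flyte-free sketch of the `asyncio.gather` fan-out pattern only. `analyze_one` is a hypothetical stand-in that just greps each source for dangerous calls; the tutorial's per-target work chains `scan_static`, `hypothesize`, `build_poc`, and `validate_in_sandbox` instead. Passing `return_exceptions=True` is a design choice of this sketch, so that one failing target does not hide the others' results.

```python
import asyncio
from typing import Any

DANGEROUS = ("strcpy", "strcat", "sprintf", "gets", "memcpy")


async def analyze_one(name: str, source: str) -> dict[str, Any]:
    # Hypothetical per-target analysis; a real pipeline would await remote tasks here.
    await asyncio.sleep(0)  # yield control, as an awaited task call would
    if not source.strip():
        raise ValueError(f"{name}: empty source")
    return {"target": name, "dangerous_calls": [fn for fn in DANGEROUS if fn in source]}


async def analyze_all(targets: dict[str, str]) -> list[dict[str, Any]]:
    # One coroutine per target file; gather runs them concurrently and
    # returns results in the same order as its inputs.
    results = await asyncio.gather(
        *(analyze_one(name, src) for name, src in targets.items()),
        return_exceptions=True,
    )
    rows: list[dict[str, Any]] = []
    for name, res in zip(targets, results):
        if isinstance(res, BaseException):
            rows.append({"target": name, "error": str(res)})
        else:
            rows.append(res)
    return rows


targets = {
    "greet.c": "void greet(char *s) { char buf[64]; strcpy(buf, s); }",
    "safe.c": 'void ok(char *s) { char buf[64]; snprintf(buf, sizeof buf, "%s", s); }',
    "empty.c": "",
}
rows = asyncio.run(analyze_all(targets))
for row in rows:
    print(row)
```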

> [!WARNING]
> This example analyzes deliberately vulnerable C code and runs generated exploit payloads in a sandbox. Use it only in controlled environments.

## Define the task environment

The agent needs an Anthropic API key and a container image with `gcc` for sandbox compilation.

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.4.0",
#    "unionai-sandbox",
#    "litellm",
# ]
# main = "run_autosec_agent"
# params = ""
# ///
"""AutoSec researcher agent — parallel vulnerability analysis with sandbox PoC validation."""

from __future__ import annotations

import asyncio
import html
import json
import os
import pathlib
import re
from typing import Any

import flyte
import flyte.errors
import flyte.report
from flyte.ai.agents import Agent

HERE = pathlib.Path(__file__).parent
TARGETS_DIR = HERE / "targets"
MODEL = os.getenv("AUTOSEC_MODEL", "claude-haiku-4-5")

# {{docs-fragment env}}
main_img = flyte.Image.from_uv_script(__file__, name="autosec-research-agent", pre=True).with_apt_packages("gcc")

env = flyte.TaskEnvironment(
    name="autosec-research-agent",
    image=main_img,
    resources=flyte.Resources(cpu=1, memory="1Gi"),
    include=[str(TARGETS_DIR)],
    secrets=[
        flyte.Secret(key="internal-anthropic-api-key", as_env_var="ANTHROPIC_API_KEY"),
    ],
)
# {{/docs-fragment env}}

def _attempt() -> int:
    tc = flyte.ctx()
    return tc.attempt_number if tc is not None else 0

def _force(flag: str) -> bool:
    return bool(os.getenv(flag) or os.getenv("AUTOSEC_FORCE_ALL"))

def _extract_json(text: str) -> dict[str, Any]:
    match = re.search(r"\{.*\}", text, re.DOTALL)
    if not match:
        raise ValueError(f"no JSON object in model reply: {text[:200]!r}")
    blob = match.group(0)
    try:
        return json.loads(blob)
    except json.JSONDecodeError:
        fixed = re.sub(r'\\(?!["\\/bfnrtu])', r"\\\\", blob)
        return json.loads(fixed)

# --- Stage 1: static analysis (CPU, OOM-prone) ------------------------------
@env.task(retries=2, timeout=30)
async def scan_static(source: str, scope: str = "whole") -> str:
    """Cheap stand-in for whole-program analysis (Joern/CodeQL in the real system)."""
    try:
        if scope == "whole" and _force("AUTOSEC_FORCE_OOM") and _attempt() == 0:
            raise flyte.errors.OOMError("whole-program graph exceeded memory limit")
        findings = _grep_dangerous_calls(source)
        return findings or "(no dangerous-call sites found)"
    except flyte.errors.OOMError as exc:
        print(f"[scan_static] {exc}; escalating resources + narrowing scope")
        return await scan_static.override(
            short_name="scan_static_more_resources", resources=flyte.Resources(cpu=2, memory="4Gi")
        )(source, scope="file")

def _grep_dangerous_calls(source: str) -> str:
    hits = []
    for i, line in enumerate(source.splitlines(), start=1):
        for fn in ("strcpy", "strcat", "sprintf", "gets", "memcpy"):
            if fn in line:
                hits.append(f"L{i}: {fn} -> {line.strip()}")
    return "\n".join(hits)

# --- Stage 2: hypothesize the vulnerability (LLM via Agent) ------------------
ANALYSIS_INSTRUCTIONS = """\
You are a vulnerability researcher. Your job is to determine whether a given \
C source file contains an exploitable memory-corruption bug reachable from argv.

You have access to these tools during your analysis:
- scan_static: Run static analysis on the source to find dangerous function calls.
- build_poc: Build a proof-of-concept payload (do not call during analysis).
- validate_in_sandbox: Compile and run the target with a PoC input (do not call during analysis).

Focus on analyzing the source and the provided static analysis findings. Call \
scan_static only if you need additional details about dangerous function usage.

Reply with ONLY a JSON object (no prose, no markdown fences):
If vulnerable: {"vulnerable": true, "function": str, \
"buffer_size": int (bytes of the overflowable buffer), "vuln_class": str, \
"reasoning": str}.
If the code looks safe (bounded copies, length checks, snprintf/strlcpy, \
etc.): {"vulnerable": false, "reasoning": str}.
"""

# --- Stage 3: build a proof-of-concept --------------------------------------
@env.task(retries=2, timeout=90)
async def build_poc(hypothesis: dict) -> dict:
    buffer_size = int(hypothesis.get("buffer_size", 64))
    payload_len = buffer_size + 64
    return {
        "payload_len": payload_len,
        "payload_repr": f'"A" * {payload_len}',
        "target_function": hypothesis.get("function", "greet"),
    }

# --- Stage 4: validate in an on-device sandbox -------------------------------
@env.task(retries=2, timeout=300)
async def validate_in_sandbox(source: str, poc: dict) -> dict:
    """Compile + run the target with the PoC input inside an on-device sandbox.

    The exploit code runs in a user-namespace sandbox on the same machine, never
    on the Flyte orchestration node (SPEC §2.6 / §7). The session is torn down
    in __aexit__ regardless of outcome (SPEC VD-5) so a stuck or failed run
    cannot leak resources.
    """
    import tempfile

    from union import sandbox as sb

    with tempfile.TemporaryDirectory() as work:
        async with sb.on_device.session(host_work_dir=work, backend="userns") as sbx:
            await sbx.put_bytes(f"{work}/target.c", source.encode())

            compile_proc = await sbx.run(
                f"gcc -fno-stack-protector -w -o {work}/target {work}/target.c",
                stdout=True,
                stderr=True,
                timeout_s=60,
            )
            compile_out, compile_err = await compile_proc.communicate_text()
            log = compile_out + compile_err
            if "error" in log.lower():
                return {
                    "triggered": False,
                    "sandbox_exit_code": -1,
                    "log": f"COMPILE_FAILED\n{log}",
                }

            payload = "A" * int(poc["payload_len"])
            run_proc = await sbx.run(
                f"{work}/target {payload}",
                stdout=True,
                stderr=True,
                timeout_s=60,
            )
            run_out, run_err = await run_proc.communicate_text()
            log = run_out + "\n" + run_err
            triggered = "SIGSEGV" in log

            return {
                "triggered": bool(triggered),
                "sandbox_exit_code": getattr(run_proc, "returncode", 0),
                "log": log,
            }

# --- Agent + hypothesize task (depends on all tools above) ------------------
hypothesis_agent = Agent(
    name="autosec-hypothesis",
    instructions=ANALYSIS_INSTRUCTIONS,
    model=MODEL,
    tools=[scan_static, build_poc, validate_in_sandbox],
    max_turns=6,
)

@env.task(retries=3, timeout=20)
async def hypothesize(source: str, static_findings: str) -> dict:
    prompt = (
        "Analyze this C source file for memory-corruption vulnerabilities.\n\n"
        f"SOURCE:\n{source}\n\nDANGEROUS CALLS:\n{static_findings}\n"
    )

    # Beat A: hang on the first attempt -> task timeout -> retry.
    timeout_on = _force("AUTOSEC_FORCE_LLM_TIMEOUT") and _attempt() == 0
    bad_on = _force("AUTOSEC_FORCE_BAD_TOOL_CALL")

    if timeout_on:
        await asyncio.sleep(600)

    result = await hypothesis_agent.run.aio(prompt, memory=[])
    raw = result.summary or ""

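    # Beat B: simulate a hallucinated/malformed model reply. The canned text below
    # contains no JSON object, so _extract_json raises ValueError, this attempt
    # fails, and the task is retried. When the timeout beat is also active it
    # consumes attempt 0, so defer this to attempt 1; that way both beats are
    # demonstrated in a single run (e.g. AUTOSEC_FORCE_ALL).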
    bad_attempt = 1 if _force("AUTOSEC_FORCE_LLM_TIMEOUT") else 0
    if bad_on and _attempt() == bad_attempt:
        raw = "Sure! The bug is somewhere around here, trust me."

    hyp = _extract_json(raw)
    if "vulnerable" not in hyp:
        hyp["vulnerable"] = "buffer_size" in hyp
    if hyp.get("vulnerable") and "buffer_size" not in hyp:
        raise ValueError(f"vulnerable hypothesis missing buffer_size: {hyp}")
    return hyp

# --- Orchestration ----------------------------------------------------------
def _load_targets() -> dict[str, str]:
    return {p.name: p.read_text() for p in sorted(TARGETS_DIR.glob("*.c"))}

@env.task
async def analyze_target(name: str, source: str) -> dict:
    findings = await scan_static(source)
    hypothesis = await hypothesize(source, findings)

    if not hypothesis.get("vulnerable"):
        poc: dict = {}
        verdict = {"triggered": False, "skipped": True}
    else:
        poc = await build_poc(hypothesis)
        verdict = await validate_in_sandbox(source, poc)

    return {
        "target": name,
        "static_findings": findings,
        "hypothesis": hypothesis,
        "poc": poc,
        "verdict": verdict,
    }

_REPORT_CSS = """
<style>
  .autosec { --bg:#ffffff; --card:#f7f8fa; --line:#e3e7ec; --muted:#5b6675;
    --text:#1b2330; --red:#c0392b; --amber:#b6791f; --green:#1e7e34; --accent:#1f6feb;
    font-family:-apple-system,BlinkMacSystemFont,"Segoe UI",Roboto,sans-serif;
    background:var(--bg); color:var(--text); padding:24px; border-radius:12px;
    border:1px solid var(--line); }
  .autosec h2 { margin:0 0 4px; font-size:20px; letter-spacing:.2px; }
  .autosec .sub { color:var(--muted); font-size:13px; margin:0 0 20px; }
  .autosec .cards { display:flex; gap:12px; flex-wrap:wrap; margin-bottom:22px; }
  .autosec .card { background:var(--card); border:1px solid var(--line);
    border-radius:10px; padding:14px 18px; min-width:120px; }
  .autosec .card .n { font-size:26px; font-weight:700; line-height:1; }
  .autosec .card .l { color:var(--muted); font-size:12px; margin-top:6px;
    text-transform:uppercase; letter-spacing:.6px; }
  .autosec table { width:100%; table-layout:fixed; border-collapse:collapse; font-size:13px;
    background:#fff; border:1px solid var(--line); border-radius:10px; overflow:hidden; }
  .autosec th, .autosec td { overflow-wrap:anywhere; }
  .autosec thead th { background:#eef1f5; color:var(--muted); text-align:left;
    font-weight:600; font-size:11px; text-transform:uppercase; letter-spacing:.6px;
    padding:11px 14px; border-bottom:1px solid var(--line); }
  .autosec tbody td { padding:11px 14px; border-bottom:1px solid var(--line);
    vertical-align:top; }
  .autosec tbody tr:last-child td { border-bottom:none; }
  .autosec tbody tr:nth-child(even) { background:#fafbfc; }
  .autosec tbody tr:hover { background:#eef4ff; }
  .autosec code { background:#eef1f5; border:1px solid var(--line); border-radius:5px;
    padding:1px 6px; font-family:ui-monospace,SFMono-Regular,Menlo,monospace; font-size:12px; }
  .autosec .num { text-align:right; font-variant-numeric:tabular-nums; }
  .autosec .reason { color:var(--muted); line-height:1.45; }
  .autosec .badge { display:inline-block; padding:3px 10px; border-radius:999px;
    font-size:11px; font-weight:700; letter-spacing:.4px; white-space:nowrap; }
  .autosec .b-exploited { background:rgba(192,57,43,.10); color:var(--red);
    border:1px solid rgba(192,57,43,.35); }
  .autosec .b-vuln { background:rgba(182,121,31,.12); color:var(--amber);
    border:1px solid rgba(182,121,31,.35); }
  .autosec .b-secure { background:rgba(30,126,52,.10); color:var(--green);
    border:1px solid rgba(30,126,52,.35); }
  .autosec col.c-target  { width:15%; }
  .autosec col.c-status  { width:11%; }
  .autosec col.c-class   { width:11%; }
  .autosec col.c-fn      { width:11%; }
  .autosec col.c-buf     { width:8%; }
  .autosec col.c-payload { width:8%; }
  .autosec col.c-exit    { width:6%; }
  .autosec col.c-reason  { width:30%; }
  .autosec .kv { display:flex; flex-wrap:wrap; gap:12px 28px; margin:16px 0 4px; }
  .autosec .kv .k { color:var(--muted); font-size:11px; text-transform:uppercase; letter-spacing:.6px; }
  .autosec .kv .v { font-weight:600; font-size:14px; margin-top:3px; }
  .autosec .section-label { font-size:11px; text-transform:uppercase; letter-spacing:.6px;
    color:var(--muted); margin:22px 0 7px; }
  .autosec .reason-block { background:var(--card); border:1px solid var(--line);
    border-radius:8px; padding:13px 15px; line-height:1.5; font-size:13px; }
  .autosec pre.code { background:#f6f8fa; border:1px solid var(--line); border-radius:8px;
    padding:14px 16px; overflow:auto; font-family:ui-monospace,SFMono-Regular,Menlo,monospace;
    font-size:12.5px; line-height:1.5; color:#1b2330; margin:0; }
  .autosec .subtabs > input[type=radio] { position:absolute; opacity:0; pointer-events:none; }
  .autosec .subnav { display:flex; flex-wrap:wrap; gap:4px; border-bottom:1px solid var(--line);
    margin:8px 0 18px; }
  .autosec .subnav label { display:inline-flex; align-items:center; gap:8px; padding:8px 14px;
    font-size:12.5px; cursor:pointer; border:1px solid transparent; border-bottom:none;
    border-radius:8px 8px 0 0; color:var(--muted); margin-bottom:-1px; }
  .autosec .subnav label:hover { background:var(--card); color:var(--text); }
  .autosec .subnav .dot { width:8px; height:8px; border-radius:50%; }
  .autosec .dot.b-exploited { background:var(--red); }
  .autosec .dot.b-vuln { background:var(--amber); }
  .autosec .dot.b-secure { background:var(--green); }
  .autosec .panel { display:none; }
</style>
"""

def _status(finding: dict) -> tuple[str, str]:
    hyp = finding.get("hypothesis") or {}
    verdict = finding.get("verdict") or {}
    if not hyp.get("vulnerable"):
        return "b-secure", "SECURE"
    if verdict.get("triggered"):
        return "b-exploited", "EXPLOITED"
    return "b-vuln", "VULNERABLE"

def _render_report_html(findings: list[dict]) -> str:
    exploited = sum(1 for f in findings if (f.get("verdict") or {}).get("triggered"))
    vulnerable = sum(1 for f in findings if (f.get("hypothesis") or {}).get("vulnerable"))
    secure = len(findings) - vulnerable

    rows = []
    for f in sorted(findings, key=lambda x: x["target"]):
        hyp = f.get("hypothesis") or {}
        verdict = f.get("verdict") or {}
        cls, label = _status(f)
        is_vuln = bool(hyp.get("vulnerable"))
        vuln_class = hyp.get("vuln_class", "\u2014") if is_vuln else "\u2014"
        fn = hyp.get("function", "\u2014") if is_vuln else "\u2014"
        buf = hyp.get("buffer_size", "\u2014") if is_vuln else "\u2014"
        payload = (f.get("poc") or {}).get("payload_len", "\u2014") if is_vuln else "\u2014"
        exit_code = verdict.get("sandbox_exit_code", "\u2014")
        rows.append(
            "<tr>"
            f"<td><code>{html.escape(str(f['target']))}</code></td>"
            f'<td><span class="badge {cls}">{label}</span></td>'
            f"<td>{html.escape(str(vuln_class))}</td>"
            f"<td><code>{html.escape(str(fn))}</code></td>"
            f'<td class="num">{html.escape(str(buf))}</td>'
            f'<td class="num">{html.escape(str(payload))}</td>'
            f'<td class="num">{html.escape(str(exit_code))}</td>'
            f'<td class="reason">{html.escape(str(hyp.get("reasoning", "")))}</td>'
            "</tr>"
        )

    return f"""{_REPORT_CSS}
    <div class="autosec">
      <h2>AutoSec &middot; security findings report</h2>
      <p class="sub">{len(findings)} target(s) analyzed in parallel &middot; PoCs validated in isolated sandbox.</p>
      <div class="cards">
        <div class="card"><div class="n">{len(findings)}</div><div class="l">Targets</div></div>
        <div class="card"><div class="n" style="color:var(--red)">{exploited}</div><div class="l">Exploited</div></div>
        <div class="card">
          <div class="n" style="color:var(--amber)">{vulnerable - exploited}</div>
          <div class="l">Vuln, PoC failed</div>
        </div>
        <div class="card"><div class="n" style="color:var(--green)">{secure}</div><div class="l">Secure</div></div>
      </div>
      <table>
        <colgroup>
          <col class="c-target"><col class="c-status"><col class="c-class"><col class="c-fn">
          <col class="c-buf"><col class="c-payload"><col class="c-exit"><col class="c-reason">
        </colgroup>
        <thead>
          <tr>
            <th>Target</th><th>Status</th><th>Vuln class</th><th>Function</th>
            <th class="num">Buffer&nbsp;(B)</th><th class="num">Payload&nbsp;(B)</th>
            <th class="num">Exit</th><th>Reasoning</th>
          </tr>
        </thead>
        <tbody>{"".join(rows)}</tbody>
      </table>
    </div>
    """

def _target_detail_html(finding: dict, source: str) -> str:
    hyp = finding.get("hypothesis") or {}
    verdict = finding.get("verdict") or {}
    cls, label = _status(finding)
    is_vuln = bool(hyp.get("vulnerable"))

    def cell(k: str, v: Any) -> str:
        return f'<div><div class="k">{html.escape(k)}</div><div class="v">{html.escape(str(v))}</div></div>'

    triggered = verdict.get("triggered")
    verdict_txt = "skipped (secure)" if verdict.get("skipped") else ("triggered" if triggered else "not triggered")
    stats = "".join(
        [
            cell("Vuln class", hyp.get("vuln_class", "\u2014") if is_vuln else "\u2014"),
            cell("Function", hyp.get("function", "\u2014") if is_vuln else "\u2014"),
            cell("Buffer (B)", hyp.get("buffer_size", "\u2014") if is_vuln else "\u2014"),
            cell("Payload (B)", (finding.get("poc") or {}).get("payload_len", "\u2014") if is_vuln else "\u2014"),
            cell("Sandbox exit", verdict.get("sandbox_exit_code", "\u2014")),
            cell("PoC", verdict_txt),
        ]
    )
    reasoning = html.escape(str(hyp.get("reasoning", "")) or "\u2014")
    code = html.escape(source or "(source unavailable)")
    return f"""
      <h3 style="margin:0 0 4px"><code>{html.escape(str(finding["target"]))}</code>
          &nbsp;<span class="badge {cls}">{label}</span></h3>
      <p class="sub">Per-target detail &middot; PoCs validated in an isolated sandbox.</p>
      <div class="kv">{stats}</div>
      <div class="section-label">Reasoning</div>
      <div class="reason-block">{reasoning}</div>
      <div class="section-label">Source</div>
      <pre class="code">{code}</pre>
    """

def _render_targets_tab_html(findings: list[dict], sources: dict[str, str]) -> str:
    ordered = sorted(findings, key=lambda x: x["target"])

    radios, nav, panels, rules = [], [], [], []
    for i, f in enumerate(ordered):
        name = f["target"]
        cls, _ = _status(f)
        checked = " checked" if i == 0 else ""
        radios.append(f'<input type="radio" name="as-targets" id="as-t{i}"{checked}>')
        nav.append(f'<label for="as-t{i}"><span class="dot {cls}"></span><code>{html.escape(str(name))}</code></label>')
        panels.append(f'<div class="panel" id="as-p{i}">{_target_detail_html(f, sources.get(name, ""))}</div>')
        rules.append(
            f'.autosec #as-t{i}:checked ~ .subnav label[for="as-t{i}"]'
            "{background:#fff;color:var(--text);border-color:var(--line);font-weight:600;}"
            f".autosec #as-t{i}:checked ~ .panels #as-p{i}{{display:block;}}"
        )

    return f"""{_REPORT_CSS}
    <style>{"".join(rules)}</style>
    <div class="autosec">
      <h2>AutoSec &middot; target detail</h2>
      <p class="sub">{len(ordered)} target(s) &middot; select a file to see its status, reasoning, and source.</p>
      <div class="subtabs">
        {"".join(radios)}
        <div class="subnav">{"".join(nav)}</div>
        <div class="panels">{"".join(panels)}</div>
      </div>
    </div>
    """

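# Demo task: deterministically fails on its first attempt so the run shows a
# task-level retry (retries=1); attempt 1 returns "Passed!".
@env.task(retries=1)
async def random_error() -> str:
    if _attempt() == 0:
        raise Exception("Random error")
    return "Passed!"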

# {{docs-fragment pipeline}}
@env.task(report=True)
async def run_autosec_agent() -> dict:
    targets = _load_targets()
    if not targets:
        raise FileNotFoundError(f"no targets found under {TARGETS_DIR}")

    findings = list(await asyncio.gather(*(analyze_target(name, src) for name, src in targets.items())))

    await flyte.report.replace.aio(_render_report_html(findings))
    flyte.report.get_tab("targets").replace(_render_targets_tab_html(findings, targets))
    await flyte.report.flush.aio()

    await random_error()

    return {
        "targets_analyzed": len(findings),
        "triggered": sum(1 for f in findings if f["verdict"].get("triggered")),
        "findings": findings,
    }
# {{/docs-fragment pipeline}}

# {{docs-fragment main}}
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(run_autosec_agent)
    print(run.url)
    run.wait()
# {{/docs-fragment main}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/autosec_research_agent/main.py*
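
The failure-injection flags in `hypothesize` are keyed on the attempt number, so which attempt fails depends on which flags are set. The sketch below is a plain-Python simulation, not Flyte code: the hypothetical `attempt_outcome` helper mirrors the branch logic in `hypothesize`, and `run_with_retries` assumes that `retries=3` means the initial attempt plus up to three retries.

```python
def attempt_outcome(attempt: int, force_timeout: bool, force_bad: bool) -> str:
    """Hypothetical mirror of the failure-injection branches in `hypothesize`."""
    # Beat A: attempt 0 sleeps well past the 20 s task timeout.
    if force_timeout and attempt == 0:
        return "timeout"
    # Beat B: the malformed reply moves to attempt 1 when Beat A is also on.
    bad_attempt = 1 if force_timeout else 0
    if force_bad and attempt == bad_attempt:
        return "malformed reply"
    return "ok"


def run_with_retries(retries: int, force_timeout: bool, force_bad: bool) -> list[str]:
    """Replay attempts until one succeeds or the retry budget is exhausted."""
    log: list[str] = []
    for attempt in range(retries + 1):  # initial attempt + `retries` retries
        outcome = attempt_outcome(attempt, force_timeout, force_bad)
        log.append(f"attempt {attempt}: {outcome}")
        if outcome == "ok":
            break
    return log


if __name__ == "__main__":
    # Equivalent of AUTOSEC_FORCE_ALL: both beats are active.
    for line in run_with_retries(retries=3, force_timeout=True, force_bad=True):
        print(line)
```

With both flags set, the simulated task fails on attempts 0 and 1 and succeeds on attempt 2, which stays within the `retries=3` budget.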

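The Python dependencies are declared at the top of the file in a `uv` script header (PEP 723 inline script metadata), which `flyte.Image.from_uv_script(__file__, ...)` also uses when building the task image: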

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.4.0",
#    "unionai-sandbox",
#    "litellm",
# ]
# ///
```

## Run the security pipeline

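Each target flows through up to four stages: static scan, LLM hypothesis, PoC construction, and sandbox validation. Targets that the LLM judges secure skip the last two. The `run_autosec_agent` driver task runs `analyze_target` for all bundled targets in parallel with `asyncio.gather`, then publishes an HTML findings report with a summary view and a per-target `targets` tab.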
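
Stripped of Flyte, the per-target control flow and the report's status classification can be sketched with plain `asyncio`. Everything here is a hypothetical stand-in: `CANNED` fixes each target's outcome instead of calling an LLM or a sandbox, `asyncio.gather` plays the role of the fan-out in `run_autosec_agent` (where each `analyze_target` call is a Flyte task rather than a local coroutine), and `status` follows the same precedence as `_status`.

```python
import asyncio

# Hypothetical canned outcomes per target: (LLM judges it vulnerable, PoC crashes it).
CANNED = {"overflow.c": (True, True), "off_by_one.c": (True, False), "safe.c": (False, False)}


async def fake_hypothesize(name: str) -> dict:
    await asyncio.sleep(0)  # stand-in for the LLM agent call
    return {"vulnerable": CANNED[name][0], "buffer_size": 64}


async def fake_validate(name: str, payload_len: int) -> dict:
    await asyncio.sleep(0)  # stand-in for compile + run inside the sandbox
    return {"triggered": CANNED[name][1]}


async def analyze(name: str) -> dict:
    hyp = await fake_hypothesize(name)
    if not hyp["vulnerable"]:
        # Secure targets skip PoC construction and sandbox validation.
        return {"target": name, "hypothesis": hyp, "verdict": {"triggered": False, "skipped": True}}
    payload_len = hyp["buffer_size"] + 64  # same sizing rule as build_poc
    verdict = await fake_validate(name, payload_len)
    return {"target": name, "hypothesis": hyp, "verdict": verdict}


def status(finding: dict) -> str:
    """Same precedence as _status: secure first, then exploited, else vulnerable."""
    if not finding["hypothesis"]["vulnerable"]:
        return "SECURE"
    return "EXPLOITED" if finding["verdict"]["triggered"] else "VULNERABLE"


async def main() -> dict[str, str]:
    # Fan out one analysis per target and wait for all of them to finish.
    findings = await asyncio.gather(*(analyze(name) for name in sorted(CANNED)))
    return {f["target"]: status(f) for f in findings}


if __name__ == "__main__":
    print(asyncio.run(main()))
```

A target marked `VULNERABLE` here corresponds to the report's "Vuln, PoC failed" card, which the report computes as `vulnerable - exploited`.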

```
# {{docs-fragment pipeline}}
@env.task(report=True)
async def run_autosec_agent() -> dict:
    targets = _load_targets()
    if not targets:
        raise FileNotFoundError(f"no targets found under {TARGETS_DIR}")

    findings = list(await asyncio.gather(*(analyze_target(name, src) for name, src in targets.items())))

    await flyte.report.replace.aio(_render_report_html(findings))
    flyte.report.get_tab("targets").replace(_render_targets_tab_html(findings, targets))
    await flyte.report.flush.aio()

    await random_error()

    return {
        "targets_analyzed": len(findings),
        "triggered": sum(1 for f in findings if f["verdict"].get("triggered")),
        "findings": findings,
    }
# {{/docs-fragment pipeline}}

# {{docs-fragment main}}
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(run_autosec_agent)
    print(run.url)
    run.wait()
# {{/docs-fragment main}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/autosec_research_agent/main.py*

## Run the agent

### Create secrets

Get an Anthropic API key from the [Anthropic console](https://console.anthropic.com/) and register it as a Flyte secret:

```
flyte create secret internal-anthropic-api-key <YOUR_ANTHROPIC_API_KEY>
```

See [Secrets](https://www.union.ai/docs/v2/union/user-guide/task-configuration/secrets/page.md) for scoping and file-based secrets.

### Run remotely

From a clone of the `unionai-examples` repository, change into the [example directory](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/autosec_research_agent) and run the script:

```
cd v2/tutorials/autosec_research_agent
uv run --script main.py
```

Follow the printed run URL to watch each target progress through the pipeline and open the report panel for the findings table and per-target detail tabs.

To demonstrate the agent's self-healing behavior, set one of the optional environment variables `AUTOSEC_FORCE_LLM_TIMEOUT`, `AUTOSEC_FORCE_BAD_TOOL_CALL`, or `AUTOSEC_FORCE_OOM`, or set `AUTOSEC_FORCE_ALL=1`.

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/agents/code-agent ===

# Coding agent

> [!NOTE]
> The complete code is available in the [`code_runner` example](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/code_runner).

This example demonstrates how to run code generated by a large language model (LLM) using a `ContainerTask`.
The agent takes a user’s question, generates Flyte 2 code using the Flyte 2 documentation as context, and runs it in an isolated container.
If the execution fails, the agent reflects on the error and tries again, repeating until the code runs successfully or a configurable iteration limit (`max_iterations`) is reached.

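Running the generated code through a `ContainerTask` keeps it in a separate container with its own image and resources, apart from the agent task and the OpenAI secret attached to the agent's environment.
This lets you execute arbitrary, untrusted code without exposing the agent itself.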
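Stripped of Flyte and LangChain, the control flow described above is a bounded generate-execute-reflect loop. The sketch below is framework-free and illustrative only: `generate` and `execute` are hypothetical callables standing in for the LLM chain and the `ContainerTask`, and the simplified `State` is not the tutorial's `AgentState`.

```python
from collections.abc import Callable
from dataclasses import dataclass, field


@dataclass
class State:
    """Simplified agent state for this sketch (not the tutorial's AgentState)."""
    messages: list[str] = field(default_factory=list)
    iterations: int = 0
    error: bool = False
    output: str = ""


def run_agent(
    generate: Callable[[State], str],
    execute: Callable[[str], tuple[str, int]],
    max_iterations: int = 3,
) -> State:
    """Generate code, run it, and retry with feedback until success or the limit."""
    state = State()
    while True:
        code = generate(state)
        state.iterations += 1
        state.output, exit_code = execute(code)
        state.error = exit_code != 0
        # Stop on success or once the iteration budget is spent.
        if not state.error or state.iterations >= max_iterations:
            return state
        # "Reflect": record the failure so the next generation can address it.
        state.messages.append(f"attempt {state.iterations} failed: {state.output}")


# Stubs: the first attempt fails, and the retry succeeds once feedback exists.
def fake_generate(state: State) -> str:
    return "good" if state.messages else "bad"


def fake_execute(code: str) -> tuple[str, int]:
    return ("ok", 0) if code == "good" else ("NameError", 1)


result = run_agent(fake_generate, fake_execute)
print(result.iterations, result.error)  # 2 False
```

In the tutorial, the same stopping decision is made in `main`, where `error == "no" or iterations >= max_iterations` ends the loop.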

## What this example demonstrates

- How to combine LLM generation with programmatic execution.
- How to run untrusted or dynamically generated code securely.
- How to iteratively improve code using agent-like behavior.

## Setting up the agent environment

Let's start by importing the necessary libraries and setting up two environments: one for the container task and another for the agent task.
This example uses `uv`'s inline script metadata (the `# /// script` header defined by PEP 723) to declare its Python version and dependencies.

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "langchain-core==0.3.66",
#    "langchain-openai==0.3.24",
#    "langchain-community==0.3.26",
#    "beautifulsoup4==4.13.4",
#    "docker==7.1.0",
# ]
# ///
```

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "langchain-core==0.3.66",
#    "langchain-openai==0.3.24",
#    "langchain-community==0.3.26",
#    "beautifulsoup4==4.13.4",
#    "docker==7.1.0",
# ]
# main = "main"
# params = ""
# ///

# {{docs-fragment code_runner_task}}
import flyte
from flyte.extras import ContainerTask
from flyte.io import File

code_runner_task = ContainerTask(
    name="run_flyte_v2",
    image=flyte.Image.from_debian_base(),
    input_data_dir="/var/inputs",
    output_data_dir="/var/outputs",
    inputs={"script": File},
    outputs={"result": str, "exit_code": str},
    command=[
        "/bin/bash",
        "-c",
        (
            "set -o pipefail && "
            "uv run --script /var/inputs/script > /var/outputs/result 2>&1; "
            "echo $? > /var/outputs/exit_code"
        ),
    ],
    resources=flyte.Resources(cpu=1, memory="1Gi"),
)

# {{/docs-fragment code_runner_task}}

# {{docs-fragment env}}
import tempfile
from typing import Optional

from langchain_core.runnables import Runnable
from pydantic import BaseModel, Field

container_env = flyte.TaskEnvironment.from_task(
    "code-runner-container", code_runner_task
)

env = flyte.TaskEnvironment(
    name="code_runner",
    secrets=[flyte.Secret(key="openai_api_key", as_env_var="OPENAI_API_KEY")],
    image=flyte.Image.from_uv_script(__file__, name="code-runner-agent"),
    resources=flyte.Resources(cpu=1),
    depends_on=[container_env],
)

# {{/docs-fragment env}}

# {{docs-fragment code_base_model}}
class Code(BaseModel):
    """Schema for code solutions to questions about Flyte v2."""

    prefix: str = Field(
        default="", description="Description of the problem and approach"
    )
    imports: str = Field(
        default="", description="Code block with just import statements"
    )
    code: str = Field(
        default="", description="Code block not including import statements"
    )

# {{/docs-fragment code_base_model}}

# {{docs-fragment agent_state}}
class AgentState(BaseModel):
    messages: list[dict[str, str]] = Field(default_factory=list)
    generation: Code = Field(default_factory=Code)
    iterations: int = 0
    error: str = "no"
    output: Optional[str] = None

# {{/docs-fragment agent_state}}

# {{docs-fragment generate_code_gen_chain}}
async def generate_code_gen_chain(debug: bool) -> Runnable:
    from langchain_core.prompts import ChatPromptTemplate
    from langchain_openai import ChatOpenAI

    # Code-generation prompt
    code_gen_prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                """
You are a coding assistant with expertise in Python.
You are able to execute the Flyte v2 code locally in a sandbox environment.

Use the following pattern to execute the code:

<code>
if __name__ == "__main__":
    flyte.init_from_config()
    print(flyte.run(...))
</code>

Your response will be shown to the user.
Here is a full set of documentation:

-------
{context}
-------

Answer the user question based on the above provided documentation.
Ensure any code you provide can be executed with all required imports and variables defined.
Structure your answer with a description of the code solution.
Then list the imports. And finally list the functioning code block.
Here is the user question:""",
            ),
            ("placeholder", "{messages}"),
        ]
    )

    expt_llm = "gpt-4o" if not debug else "gpt-4o-mini"
    llm = ChatOpenAI(temperature=0, model=expt_llm)

    code_gen_chain = code_gen_prompt | llm.with_structured_output(Code)
    return code_gen_chain

# {{/docs-fragment generate_code_gen_chain}}

# {{docs-fragment docs_retriever}}
@env.task
async def docs_retriever(url: str) -> str:
    from bs4 import BeautifulSoup
    from langchain_community.document_loaders.recursive_url_loader import (
        RecursiveUrlLoader,
    )

    loader = RecursiveUrlLoader(
        url=url, max_depth=20, extractor=lambda x: BeautifulSoup(x, "html.parser").text
    )
    docs = loader.load()

    # Sort pages by source URL (descending) and concatenate their text
    d_sorted = sorted(docs, key=lambda x: x.metadata["source"])
    d_reversed = list(reversed(d_sorted))

    concatenated_content = "\n\n\n --- \n\n\n".join(
        [doc.page_content for doc in d_reversed]
    )
    return concatenated_content

# {{/docs-fragment docs_retriever}}

# {{docs-fragment generate}}
@env.task
async def generate(
    question: str, state: AgentState, concatenated_content: str, debug: bool
) -> AgentState:
    """
    Generate a code solution.

    Args:
        question (str): The user question
        state (AgentState): The current agent state
        concatenated_content (str): The concatenated docs content
        debug (bool): Debug mode; uses the smaller gpt-4o-mini model

    Returns:
        AgentState: Updated state containing the new generation
    """

    print("---GENERATING CODE SOLUTION---")

    messages = state.messages
    iterations = state.iterations
    error = state.error

    # We have been routed back to generation with an error
    if error == "yes":
        messages += [
            {
                "role": "user",
                "content": (
                    "Now, try again. Invoke the code tool to structure the output "
                    "with a prefix, imports, and code block:"
                ),
            }
        ]

    code_gen_chain = await generate_code_gen_chain(debug)

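    # Seed the history with the user question on the first pass so that
    # later retries (which send the accumulated history) still include it.
    if not messages:
        messages = [{"role": "user", "content": question}]

    # Solution
    code_solution = code_gen_chain.invoke(
        {"context": concatenated_content, "messages": messages}
    )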

    messages += [
        {
            "role": "assistant",
            "content": f"{code_solution.prefix} \n Imports: {code_solution.imports} \n Code: {code_solution.code}",
        }
    ]

    return AgentState(
        messages=messages,
        generation=code_solution,
        iterations=iterations + 1,
        error=error,
        output=state.output,
    )

# {{/docs-fragment generate}}

# {{docs-fragment code_check}}
@env.task
async def code_check(state: AgentState) -> AgentState:
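    """
    Check the generated code by running its imports, then the full script,
    in the sandboxed container task.

    Args:
        state (AgentState): The current agent state

    Returns:
        AgentState: Updated state with the error flag and execution output
    """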

    print("---CHECKING CODE---")

    # State
    messages = state.messages
    code_solution = state.generation
    iterations = state.iterations

    # Get solution components
    imports = code_solution.imports.strip()
    code = code_solution.code.strip()

    # Create temp file for imports
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".py", delete=False
    ) as imports_file:
        imports_file.write(imports + "\n")
        imports_path = imports_file.name

    # Create temp file for code body
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as code_file:
        code_file.write(imports + "\n" + code + "\n")
        code_path = code_file.name

    # Check imports
    import_output, import_exit_code = await code_runner_task(
        script=await File.from_local(imports_path)
    )

    if import_exit_code.strip() != "0":
        print("---CODE IMPORT CHECK: FAILED---")
        error_message = [
            {
                "role": "user",
                "content": f"Your solution failed the import test: {import_output}",
            }
        ]
        messages += error_message
        return AgentState(
            generation=code_solution,
            messages=messages,
            iterations=iterations,
            error="yes",
            output=import_output,
        )
    else:
        print("---CODE IMPORT CHECK: PASSED---")

    # Check execution
    code_output, code_exit_code = await code_runner_task(
        script=await File.from_local(code_path)
    )

    if code_exit_code.strip() != "0":
        print("---CODE BLOCK CHECK: FAILED---")
        error_message = [
            {
                "role": "user",
                "content": f"Your solution failed the code execution test: {code_output}",
            }
        ]
        messages += error_message
        return AgentState(
            generation=code_solution,
            messages=messages,
            iterations=iterations,
            error="yes",
            output=code_output,
        )
    else:
        print("---CODE BLOCK CHECK: PASSED---")

    # No errors
    print("---NO CODE TEST FAILURES---")

    return AgentState(
        generation=code_solution,
        messages=messages,
        iterations=iterations,
        error="no",
        output=code_output,
    )

# {{/docs-fragment code_check}}

# {{docs-fragment reflect}}
@env.task
async def reflect(
    state: AgentState, concatenated_content: str, debug: bool
) -> AgentState:
    """
    Reflect on errors.

    Args:
        state (AgentState): The current agent state
        concatenated_content (str): Concatenated docs content
        debug (bool): Debug mode; uses the smaller gpt-4o-mini model

    Returns:
        AgentState: Updated state with the reflection appended to the messages
    """

    print("---REFLECTING---")

    # State
    messages = state.messages
    iterations = state.iterations
    code_solution = state.generation

    # Prompt reflection
    code_gen_chain = await generate_code_gen_chain(debug)

    # Add reflection
    reflections = code_gen_chain.invoke(
        {"context": concatenated_content, "messages": messages}
    )

    messages += [
        {
            "role": "assistant",
            "content": f"Here are reflections on the error: {reflections}",
        }
    ]

    return AgentState(
        generation=code_solution,
        messages=messages,
        iterations=iterations,
        error=state.error,
        output=state.output,
    )

# {{/docs-fragment reflect}}

# {{docs-fragment main}}
@env.task
async def main(
    question: str = (
        "Define a two-task pattern where the second catches OOM from the first and retries with more memory."
    ),
    url: str = "https://pre-release-v2.docs-builder.pages.dev/docs/byoc/user-guide/",
    max_iterations: int = 3,
    debug: bool = False,
) -> str:
    concatenated_content = await docs_retriever(url=url)

    state: AgentState = AgentState()
    iterations = 0

    while True:
        with flyte.group(f"code-generation-pass-{iterations + 1}"):
            state = await generate(question, state, concatenated_content, debug)
            state = await code_check(state)

            error = state.error
            iterations = state.iterations

            if error == "no" or iterations >= max_iterations:
                print("---DECISION: FINISH---")
                code_solution = state.generation

                prefix = code_solution.prefix
                imports = code_solution.imports
                code = code_solution.code

                code_output = state.output

                return f"""{prefix}

{imports}
{code}

Result of code execution:
{code_output}
"""
            else:
                print("---DECISION: RE-TRY SOLUTION---")
                state = await reflect(state, concatenated_content, debug)

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
    run.wait()

# {{/docs-fragment main}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/code_runner/agent.py*
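The `code_runner_task` command redirects the script's stdout and stderr to `/var/outputs/result` and writes the exit status (`$?`) to `/var/outputs/exit_code`, and both come back to `code_check` as strings. The following local sketch reproduces that contract with `subprocess`, purely to illustrate it: it runs the file with the current Python interpreter instead of `uv run --script`, and it provides none of the container's isolation. `run_script` is a hypothetical helper, not part of Flyte.

```python
import subprocess
import sys
import tempfile
from pathlib import Path


def run_script(path: str) -> tuple[str, str]:
    """Run a Python file; return (combined stdout+stderr, exit code as text)."""
    proc = subprocess.run(
        [sys.executable, path],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # like `2>&1`: merge stderr into stdout
        text=True,
    )
    # `echo $? > /var/outputs/exit_code` writes the status plus a newline.
    return proc.stdout, f"{proc.returncode}\n"


with tempfile.TemporaryDirectory() as tmp:
    script = Path(tmp) / "script.py"
    script.write_text("import sys\nprint('hello')\nsys.exit(3)\n")
    result, exit_code = run_script(str(script))

print(repr(result))               # 'hello\n'
print(exit_code.strip() != "0")   # True: code_check would report a failure
```

The trailing newline written by `echo` is presumably why `code_check` calls `.strip()` before comparing the exit code against `"0"`.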

> [!NOTE]
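> The agent task reads the OpenAI API key from a Flyte secret named `openai_api_key` and exposes it as the `OPENAI_API_KEY` environment variable. Create the secret with: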
>
> ```
> flyte create secret openai_api_key <YOUR_OPENAI_API_KEY>
> ```

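The agent asks the LLM for structured output: the `Code` model splits each response into a description of the approach (`prefix`), the import statements (`imports`), and the code body (`code`). This structure lets us:

- Enforce consistent formatting
- Make debugging easier
- Log and analyze generations systematically

Keeping the description alongside the code makes each attempt easier to review and trace, and keeping the imports separate lets `code_check` test them on their own before it runs the full script.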

```
from pydantic import BaseModel, Field


class Code(BaseModel):
    """Schema for code solutions to questions about Flyte v2."""

    prefix: str = Field(
        default="", description="Description of the problem and approach"
    )
    imports: str = Field(
        default="", description="Code block with just import statements"
    )
    code: str = Field(
        default="", description="Code block not including import statements"
    )
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/code_runner/agent.py*
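Because each generation arrives as separate fields, `code_check` writes two temporary scripts: one with only the imports, and one with the imports followed by the code body. A stdlib-only sketch of that step, using a dataclass in place of the pydantic model; `build_check_scripts` is a hypothetical helper, not a function in the tutorial:

```python
from dataclasses import dataclass


@dataclass
class Code:
    """Stdlib stand-in for the tutorial's pydantic `Code` model (same fields)."""
    prefix: str = ""
    imports: str = ""
    code: str = ""


def build_check_scripts(solution: Code) -> tuple[str, str]:
    """Return (imports-only script, full script), as code_check writes them."""
    imports = solution.imports.strip()
    body = solution.code.strip()
    return imports + "\n", imports + "\n" + body + "\n"


solution = Code(
    prefix="Add two numbers and print the result.",
    imports="import operator",
    code="print(operator.add(2, 3))",
)
imports_script, full_script = build_check_scripts(solution)
print(full_script, end="")
```

Running the imports-only script first means an import failure is reported separately from a runtime failure in the code body.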

Next, we define a state model that carries the agent's history across iterations: the message history, the latest generated code, the iteration count, an error flag, and the most recent execution output.

Maintaining this history allows the agent to reflect on past attempts, avoid repeating mistakes,
and iteratively improve the generated code.

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "langchain-core==0.3.66",
#    "langchain-openai==0.3.24",
#    "langchain-community==0.3.26",
#    "beautifulsoup4==4.13.4",
#    "docker==7.1.0",
# ]
# main = "main"
# params = ""
# ///

# {{docs-fragment code_runner_task}}
import flyte
from flyte.extras import ContainerTask
from flyte.io import File

code_runner_task = ContainerTask(
    name="run_flyte_v2",
    image=flyte.Image.from_debian_base(),
    input_data_dir="/var/inputs",
    output_data_dir="/var/outputs",
    inputs={"script": File},
    outputs={"result": str, "exit_code": str},
    command=[
        "/bin/bash",
        "-c",
        (
            "set -o pipefail && "
            "uv run --script /var/inputs/script > /var/outputs/result 2>&1; "
            "echo $? > /var/outputs/exit_code"
        ),
    ],
    resources=flyte.Resources(cpu=1, memory="1Gi"),
)

# {{/docs-fragment code_runner_task}}

# {{docs-fragment env}}
import tempfile
from typing import Optional

from langchain_core.runnables import Runnable
from pydantic import BaseModel, Field

container_env = flyte.TaskEnvironment.from_task(
    "code-runner-container", code_runner_task
)

env = flyte.TaskEnvironment(
    name="code_runner",
    secrets=[flyte.Secret(key="openai_api_key", as_env_var="OPENAI_API_KEY")],
    image=flyte.Image.from_uv_script(__file__, name="code-runner-agent"),
    resources=flyte.Resources(cpu=1),
    depends_on=[container_env],
)

# {{/docs-fragment env}}

# {{docs-fragment code_base_model}}
class Code(BaseModel):
    """Schema for code solutions to questions about Flyte v2."""

    prefix: str = Field(
        default="", description="Description of the problem and approach"
    )
    imports: str = Field(
        default="", description="Code block with just import statements"
    )
    code: str = Field(
        default="", description="Code block not including import statements"
    )

# {{/docs-fragment code_base_model}}

# {{docs-fragment agent_state}}
class AgentState(BaseModel):
    messages: list[dict[str, str]] = Field(default_factory=list)
    generation: Code = Field(default_factory=Code)
    iterations: int = 0
    error: str = "no"
    output: Optional[str] = None

# {{/docs-fragment agent_state}}

# {{docs-fragment generate_code_gen_chain}}
async def generate_code_gen_chain(debug: bool) -> Runnable:
    from langchain_core.prompts import ChatPromptTemplate
    from langchain_openai import ChatOpenAI

    # Code-generation prompt
    code_gen_prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                """
You are a coding assistant with expertise in Python.
You are able to execute the Flyte v2 code locally in a sandbox environment.

Use the following pattern to execute the code:

<code>
if __name__ == "__main__":
    flyte.init_from_config()
    print(flyte.run(...))
</code>

Your response will be shown to the user.
Here is a full set of documentation:

-------
{context}
-------

Answer the user question based on the above provided documentation.
Ensure any code you provide can be executed with all required imports and variables defined.
Structure your answer with a description of the code solution.
Then list the imports. And finally list the functioning code block.
Here is the user question:""",
            ),
            ("placeholder", "{messages}"),
        ]
    )

    expt_llm = "gpt-4o" if not debug else "gpt-4o-mini"
    llm = ChatOpenAI(temperature=0, model=expt_llm)

    code_gen_chain = code_gen_prompt | llm.with_structured_output(Code)
    return code_gen_chain

# {{/docs-fragment generate_code_gen_chain}}

# {{docs-fragment docs_retriever}}
@env.task
async def docs_retriever(url: str) -> str:
    from bs4 import BeautifulSoup
    from langchain_community.document_loaders.recursive_url_loader import (
        RecursiveUrlLoader,
    )

    loader = RecursiveUrlLoader(
        url=url, max_depth=20, extractor=lambda x: BeautifulSoup(x, "html.parser").text
    )
    docs = loader.load()

    # Sort documents by source URL, reverse the order, and join their text
    d_sorted = sorted(docs, key=lambda x: x.metadata["source"])
    d_reversed = list(reversed(d_sorted))

    concatenated_content = "\n\n\n --- \n\n\n".join(
        [doc.page_content for doc in d_reversed]
    )
    return concatenated_content

# {{/docs-fragment docs_retriever}}

# {{docs-fragment generate}}
@env.task
async def generate(
    question: str, state: AgentState, concatenated_content: str, debug: bool
) -> AgentState:
    """
    Generate a code solution

    Args:
        question (str): The user question
        state (AgentState): The current agent state
        concatenated_content (str): The concatenated docs content
        debug (bool): Debug mode (use gpt-4o-mini instead of gpt-4o)

    Returns:
        AgentState: Updated state with the new generation and an incremented iteration count
    """

    print("---GENERATING CODE SOLUTION---")

    messages = state.messages
    iterations = state.iterations
    error = state.error

    # We have been routed back to generation with an error
    if error == "yes":
        messages += [
            {
                "role": "user",
                "content": (
                    "Now, try again. Invoke the code tool to structure the output "
                    "with a prefix, imports, and code block:"
                ),
            }
        ]

    code_gen_chain = await generate_code_gen_chain(debug)

    # Solution
    code_solution = code_gen_chain.invoke(
        {
            "context": concatenated_content,
            "messages": (
                messages if messages else [{"role": "user", "content": question}]
            ),
        }
    )

    messages += [
        {
            "role": "assistant",
            "content": f"{code_solution.prefix} \n Imports: {code_solution.imports} \n Code: {code_solution.code}",
        }
    ]

    return AgentState(
        messages=messages,
        generation=code_solution,
        iterations=iterations + 1,
        error=error,
        output=state.output,
    )

# {{/docs-fragment generate}}

# {{docs-fragment code_check}}
@env.task
async def code_check(state: AgentState) -> AgentState:
    """
    Check code

    Args:
        state (AgentState): The current agent state

    Returns:
        AgentState: Updated state with the error flag ("yes"/"no") and the execution output
    """

    print("---CHECKING CODE---")

    # State
    messages = state.messages
    code_solution = state.generation
    iterations = state.iterations

    # Get solution components
    imports = code_solution.imports.strip()
    code = code_solution.code.strip()

    # Create temp file for imports
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".py", delete=False
    ) as imports_file:
        imports_file.write(imports + "\n")
        imports_path = imports_file.name

    # Create temp file for code body
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as code_file:
        code_file.write(imports + "\n" + code + "\n")
        code_path = code_file.name

    # Check imports
    import_output, import_exit_code = await code_runner_task(
        script=await File.from_local(imports_path)
    )

    if import_exit_code.strip() != "0":
        print("---CODE IMPORT CHECK: FAILED---")
        error_message = [
            {
                "role": "user",
                "content": f"Your solution failed the import test: {import_output}",
            }
        ]
        messages += error_message
        return AgentState(
            generation=code_solution,
            messages=messages,
            iterations=iterations,
            error="yes",
            output=import_output,
        )
    else:
        print("---CODE IMPORT CHECK: PASSED---")

    # Check execution
    code_output, code_exit_code = await code_runner_task(
        script=await File.from_local(code_path)
    )

    if code_exit_code.strip() != "0":
        print("---CODE BLOCK CHECK: FAILED---")
        error_message = [
            {
                "role": "user",
                "content": f"Your solution failed the code execution test: {code_output}",
            }
        ]
        messages += error_message
        return AgentState(
            generation=code_solution,
            messages=messages,
            iterations=iterations,
            error="yes",
            output=code_output,
        )
    else:
        print("---CODE BLOCK CHECK: PASSED---")

    # No errors
    print("---NO CODE TEST FAILURES---")

    return AgentState(
        generation=code_solution,
        messages=messages,
        iterations=iterations,
        error="no",
        output=code_output,
    )

# {{/docs-fragment code_check}}

# {{docs-fragment reflect}}
@env.task
async def reflect(
    state: AgentState, concatenated_content: str, debug: bool
) -> AgentState:
    """
    Reflect on errors

    Args:
        state (AgentState): The current agent state
        concatenated_content (str): Concatenated docs content
        debug (bool): Debug mode (use gpt-4o-mini instead of gpt-4o)

    Returns:
        AgentState: Updated state with the reflection appended to the messages
    """

    print("---REFLECTING---")

    # State
    messages = state.messages
    iterations = state.iterations
    code_solution = state.generation

    # Prompt reflection
    code_gen_chain = await generate_code_gen_chain(debug)

    # Add reflection
    reflections = code_gen_chain.invoke(
        {"context": concatenated_content, "messages": messages}
    )

    messages += [
        {
            "role": "assistant",
            "content": f"Here are reflections on the error: {reflections}",
        }
    ]

    return AgentState(
        generation=code_solution,
        messages=messages,
        iterations=iterations,
        error=state.error,
        output=state.output,
    )

# {{/docs-fragment reflect}}

# {{docs-fragment main}}
@env.task
async def main(
    question: str = (
        "Define a two-task pattern where the second catches OOM from the first and retries with more memory."
    ),
    url: str = "https://pre-release-v2.docs-builder.pages.dev/docs/byoc/user-guide/",
    max_iterations: int = 3,
    debug: bool = False,
) -> str:
    concatenated_content = await docs_retriever(url=url)

    state: AgentState = AgentState()
    iterations = 0

    while True:
        with flyte.group(f"code-generation-pass-{iterations + 1}"):
            state = await generate(question, state, concatenated_content, debug)
            state = await code_check(state)

            error = state.error
            iterations = state.iterations

            if error == "no" or iterations >= max_iterations:
                print("---DECISION: FINISH---")
                code_solution = state.generation

                prefix = code_solution.prefix
                imports = code_solution.imports
                code = code_solution.code

                code_output = state.output

                return f"""{prefix}

{imports}
{code}

Result of code execution:
{code_output}
"""
            else:
                print("---DECISION: RE-TRY SOLUTION---")
                state = await reflect(state, concatenated_content, debug)

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
    run.wait()

# {{/docs-fragment main}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/code_runner/agent.py*

## Retrieve docs

We define a task to load documents from a given URL and concatenate them into a single string.
This string is then used as part of the LLM prompt.

We set `max_depth=20`, which caps how many levels of links the loader follows from the starting URL so the crawl cannot recurse indefinitely.
However, even with this limit, the resulting context can still be quite large.
To handle this, we use an LLM with an extended context window (GPT-4o by default, or GPT-4o mini when `debug` is enabled).

> **📝 Note**
>
> Appending all documents into a single string can result in extremely large contexts, potentially exceeding the LLM’s token limit.
> If your dataset grows beyond what a single prompt can handle, there are a couple of strategies you can use.
> One option is to apply Retrieval-Augmented Generation (RAG), where you chunk the documents, embed them using a model,
> store the vectors in a vector database, and retrieve only the most relevant pieces at inference time.
>
> An alternative approach is to pass references to full files into the prompt, allowing the LLM to decide which files are most relevant based
> on natural-language search over file paths, summaries, or even contents. This method assumes that only a subset of files
> will be necessary for a given task, and the LLM is responsible for navigating the structure and identifying what to read.
> While this can be a lighter-weight solution for smaller datasets, its effectiveness depends on how well the LLM can
> reason over file references and the reliability of its internal search heuristics.
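As an aside, here is a minimal, dependency-free sketch of the chunk-and-retrieve step behind the RAG option described in the note. It is not part of the tutorial's code: the helper names (`chunk`, `vectorize`, `cosine`, `retrieve`) are hypothetical, and a bag-of-words cosine similarity stands in for a real embedding model and vector database.

```python
import math
import re
from collections import Counter


def chunk(text: str, size: int = 200, overlap: int = 50) -> list[str]:
    """Split text into overlapping windows of `size` words (assumes overlap < size)."""
    words = text.split()
    step = size - overlap
    return [" ".join(words[i : i + size]) for i in range(0, max(len(words) - overlap, 1), step)]


def vectorize(text: str) -> Counter:
    """Bag-of-words term counts, standing in for an embedding model."""
    return Counter(re.findall(r"[a-z0-9_]+", text.lower()))


def cosine(a: Counter, b: Counter) -> float:
    """Cosine similarity between two sparse term-count vectors."""
    dot = sum(a[t] * b[t] for t in a.keys() & b.keys())
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


def retrieve(question: str, docs: list[str], k: int = 3) -> list[str]:
    """Chunk every document, then return the k chunks most similar to the question."""
    chunks = [c for doc in docs for c in chunk(doc)]
    q = vectorize(question)
    return sorted(chunks, key=lambda c: cosine(q, vectorize(c)), reverse=True)[:k]


docs = [
    "Use flyte.Resources to set the memory and CPU requested by a task.",
    "Secrets are injected into the task container as environment variables.",
]
print(retrieve("How do I set task memory?", docs, k=1))
# ['Use flyte.Resources to set the memory and CPU requested by a task.']
```

Only the retrieved chunks, rather than the whole crawl, would then be passed as `{context}` to the prompt.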

```
@env.task
async def docs_retriever(url: str) -> str:
    from bs4 import BeautifulSoup
    from langchain_community.document_loaders.recursive_url_loader import (
        RecursiveUrlLoader,
    )

    loader = RecursiveUrlLoader(
        url=url, max_depth=20, extractor=lambda x: BeautifulSoup(x, "html.parser").text
    )
    docs = loader.load()

    # Sort documents by source URL, reverse the order, and join their text
    d_sorted = sorted(docs, key=lambda x: x.metadata["source"])
    d_reversed = list(reversed(d_sorted))

    concatenated_content = "\n\n\n --- \n\n\n".join(
        [doc.page_content for doc in d_reversed]
    )
    return concatenated_content

```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/code_runner/agent.py*
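The concatenation step in `docs_retriever` can be tried in isolation, together with a quick check of whether the result is likely to fit in the model's context window. This is a minimal sketch under stated assumptions: `Doc` is a hypothetical stand-in for a LangChain `Document`, and `estimate_tokens` uses the rough rule of thumb of about four characters per English token, so use a real tokenizer (for example, `tiktoken`) when you need exact counts.

```python
from dataclasses import dataclass, field

# Same separator that docs_retriever places between pages.
SEPARATOR = "\n\n\n --- \n\n\n"


@dataclass
class Doc:
    """Hypothetical stand-in for a LangChain Document: page text plus metadata."""
    page_content: str
    metadata: dict = field(default_factory=dict)


def concatenate(docs: list[Doc]) -> str:
    """Sort by source URL, reverse the order, and join the page text, as docs_retriever does."""
    ordered = list(reversed(sorted(docs, key=lambda d: d.metadata["source"])))
    return SEPARATOR.join(d.page_content for d in ordered)


def estimate_tokens(text: str, chars_per_token: float = 4.0) -> int:
    """Rough token estimate; real tokenizers will give different counts."""
    return int(len(text) / chars_per_token)


def fits(text: str, budget_tokens: int) -> bool:
    """True if the estimated token count is within the budget."""
    return estimate_tokens(text) <= budget_tokens


docs = [
    Doc("alpha", {"source": "https://example.com/a"}),
    Doc("beta", {"source": "https://example.com/b"}),
]
context = concatenate(docs)
print(context.startswith("beta"), estimate_tokens(context))  # True 5
```

In practice you would compare the estimate against the chosen model's limit (GPT-4o's context window is 128K tokens) and switch to one of the strategies from the note when the crawl does not fit.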

## Code generation

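Next, we define a utility function that constructs the LLM chain responsible for generating Python code from user input. The chain pipes
a LangChain `ChatPromptTemplate`, which injects the retrieved documentation (`{context}`) and the conversation history (`{messages}`), into an OpenAI chat model wrapped with `with_structured_output(Code)`, so each response is parsed into the `Code` schema. The system prompt steers the model toward Flyte 2-compatible Python scripts.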

```
async def generate_code_gen_chain(debug: bool) -> Runnable:
    from langchain_core.prompts import ChatPromptTemplate
    from langchain_openai import ChatOpenAI

    # Code-generation prompt
    code_gen_prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                """
You are a coding assistant with expertise in Python.
You are able to execute the Flyte v2 code locally in a sandbox environment.

Use the following pattern to execute the code:

<code>
if __name__ == "__main__":
    flyte.init_from_config()
    print(flyte.run(...))
</code>

Your response will be shown to the user.
Here is a full set of documentation:

-------
{context}
-------

Answer the user question based on the above provided documentation.
Ensure any code you provide can be executed with all required imports and variables defined.
Structure your answer with a description of the code solution.
Then list the imports. And finally list the functioning code block.
Here is the user question:""",
            ),
            ("placeholder", "{messages}"),
        ]
    )

    expt_llm = "gpt-4o" if not debug else "gpt-4o-mini"
    llm = ChatOpenAI(temperature=0, model=expt_llm)

    code_gen_chain = code_gen_prompt | llm.with_structured_output(Code)
    return code_gen_chain

```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/code_runner/agent.py*
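The `main` task above drives the agent. Each pass calls `generate` and then `code_check`. It stops once the check passes or `max_iterations` is reached. Otherwise it calls `reflect` and loops. The sketch below shows that control flow in plain Python, without Flyte or an LLM. It assumes a simplified `State` dataclass in place of `AgentState` and uses hypothetical stub steps (`fake_generate`, `fake_check`, `fake_reflect`), where the second attempt passes.

```python
from dataclasses import dataclass, field, replace
from typing import Callable, Optional


@dataclass
class State:
    """Simplified stand-in for AgentState (assumption: plain dataclass, no pydantic)."""

    messages: list[str] = field(default_factory=list)
    generation: str = ""
    iterations: int = 0
    error: str = "no"
    output: Optional[str] = None


Step = Callable[[State], State]


def run_agent_loop(generate: Step, check: Step, reflect: Step, max_iterations: int = 3) -> State:
    """Generate, then check; either finish or reflect and try again."""
    state = State()
    while True:
        state = check(generate(state))
        # Stop on success, or once the retry budget is used up
        if state.error == "no" or state.iterations >= max_iterations:
            return state
        state = reflect(state)


# Hypothetical stubs: generate bumps the counter, and the second attempt passes
def fake_generate(s: State) -> State:
    n = s.iterations + 1
    return replace(s, generation=f"attempt {n}", iterations=n, messages=s.messages + [f"attempt {n}"])


def fake_check(s: State) -> State:
    ok = s.generation == "attempt 2"
    return replace(s, error="no" if ok else "yes", output="ok" if ok else "boom")


def fake_reflect(s: State) -> State:
    return replace(s, messages=s.messages + ["reflection"])


final = run_agent_loop(fake_generate, fake_check, fake_reflect)
print(final.iterations, final.error)  # 2 no
```

As in the tutorial, `generate` increments the iteration counter, so a `reflect` step does not use up an extra attempt.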

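We then define a `generate` task that produces the code solution.
The LLM chain returns structured output that follows the `Code` schema, so every answer has three parts:
`prefix`, a short description of the problem and approach; `imports`, the import statements only;
and `code`, the executable body without the imports.
Keeping the imports separate makes the solution easier to test, because they can be checked on their own before the full program runs.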

```
@env.task
async def generate(
    question: str, state: AgentState, concatenated_content: str, debug: bool
) -> AgentState:
    """
    Generate a code solution

    Args:
        question (str): The user question
        state (AgentState): The current agent state
        concatenated_content (str): The concatenated docs content
        debug (bool): If True, use gpt-4o-mini instead of gpt-4o

    Returns:
        AgentState: Updated state with the new generation and an incremented iteration count
    """

    print("---GENERATING CODE SOLUTION---")

    messages = state.messages
    iterations = state.iterations
    error = state.error

    # On the first pass, seed the history with the user question so that
    # the LLM still sees it on later retries
    if not messages:
        messages = [{"role": "user", "content": question}]

    # We have been routed back to generation with an error
    if error == "yes":
        messages += [
            {
                "role": "user",
                "content": (
                    "Now, try again. Invoke the code tool to structure the output "
                    "with a prefix, imports, and code block:"
                ),
            }
        ]

    code_gen_chain = await generate_code_gen_chain(debug)

    # Solution
    code_solution = code_gen_chain.invoke(
        {"context": concatenated_content, "messages": messages}
    )

    messages += [
        {
            "role": "assistant",
            "content": f"{code_solution.prefix} \n Imports: {code_solution.imports} \n Code: {code_solution.code}",
        }
    ]

    return AgentState(
        messages=messages,
        generation=code_solution,
        iterations=iterations + 1,
        error=error,
        output=state.output,
    )
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/code_runner/agent.py*
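The three-part split is what makes the later checks possible. The imports can be run on their own, and then the imports together with the body. The sketch below uses a plain dataclass in place of the pydantic `Code` model, which is an assumption made to keep it dependency-free. It shows how the parts could be assembled into the two scripts and into the assistant message that `generate` appends to the history. The helper names `import_check_script`, `full_script`, and `assistant_message` are hypothetical.

```python
from dataclasses import dataclass


@dataclass
class Code:
    """Plain-dataclass stand-in for the pydantic `Code` schema (assumption)."""

    prefix: str = ""   # description of the problem and approach
    imports: str = ""  # import statements only
    code: str = ""     # executable body, without imports


def import_check_script(solution: Code) -> str:
    """Script for the first check: the imports on their own."""
    return solution.imports.strip() + "\n"


def full_script(solution: Code) -> str:
    """Script for the second check: imports followed by the body."""
    return solution.imports.strip() + "\n" + solution.code.strip() + "\n"


def assistant_message(solution: Code) -> dict[str, str]:
    """History entry appended after each generation, formatted as in `generate`."""
    return {
        "role": "assistant",
        "content": f"{solution.prefix} \n Imports: {solution.imports} \n Code: {solution.code}",
    }


solution = Code(
    prefix="Sum a list of floats accurately.",
    imports="import math",
    code="print(math.fsum([0.1] * 10))",
)
print(full_script(solution), end="")
```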

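`code_runner_task`, a `ContainerTask`, then executes this code in a separate container built from a Debian base image, with 1 CPU and 1 GiB of memory.
It takes the generated script as a `File`, runs it with `uv run --script`, and returns the combined stdout and stderr as `result` and the exit code as `exit_code`, both as strings.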

```
import flyte
from flyte.extras import ContainerTask
from flyte.io import File

code_runner_task = ContainerTask(
    name="run_flyte_v2",
    image=flyte.Image.from_debian_base(),
    input_data_dir="/var/inputs",
    output_data_dir="/var/outputs",
    inputs={"script": File},
    outputs={"result": str, "exit_code": str},
    command=[
        "/bin/bash",
        "-c",
        (
            "set -o pipefail && "
            "uv run --script /var/inputs/script > /var/outputs/result 2>&1; "
            "echo $? > /var/outputs/exit_code"
        ),
    ],
    resources=flyte.Resources(cpu=1, memory="1Gi"),
)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/code_runner/agent.py*
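For local experimentation, the contract of `code_runner_task` can be approximated with a child process. The task takes a script and returns the combined output and the exit code, both as strings. This is a sketch under stated assumptions and not a replacement for the container. It runs the current Python interpreter (`sys.executable`) instead of `uv run --script`, so only packages already installed locally are importable. It also provides no isolation. The function name `run_script` is hypothetical.

```python
import subprocess
import sys
import tempfile
from pathlib import Path


def run_script(source: str, timeout: float = 60.0) -> tuple[str, str]:
    """Run `source` as a standalone Python script in a child process.

    Returns (combined stdout and stderr, exit code as a string), mirroring the
    `result` and `exit_code` outputs of `code_runner_task`. This is NOT a
    sandbox: the script runs with the caller's permissions.
    """
    with tempfile.TemporaryDirectory() as tmp:
        script = Path(tmp) / "script.py"
        script.write_text(source)
        proc = subprocess.run(
            [sys.executable, str(script)],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # merge stderr into stdout, like `2>&1`
            text=True,
            timeout=timeout,
        )
    return proc.stdout, str(proc.returncode)


output, exit_code = run_script("print('hello')")
print(repr(output), exit_code)  # 'hello\n' 0
```

If the script runs longer than `timeout`, `subprocess.run` raises `subprocess.TimeoutExpired` instead of returning an exit code.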

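The `code_check` task verifies that the generated code runs as expected.
It runs the import statements on their own first, then the imports together with the code body, each time through `code_runner_task`.
If either run exits with a nonzero code, it appends the captured output to `messages` as a user message, sets `error` to `"yes"`, and returns early; otherwise it sets `error` to `"no"`.
In both cases, the program output is stored in the state's `output` field.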
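Without Flyte and the container, the two-stage check just described reduces to a plain function. It takes the runner as a callable, an assumption made so the sketch is self-contained. A hypothetical `fake_run` stands in for `code_runner_task`, so no generated code is actually executed. The function name `check_solution` is also hypothetical.

```python
from typing import Callable

# Runner contract (assumption): script source in, (combined output, exit code) out,
# matching the `result` and `exit_code` outputs of `code_runner_task`
Runner = Callable[[str], tuple[str, str]]


def check_solution(
    imports: str, code: str, run: Runner, messages: list[dict[str, str]]
) -> tuple[str, str]:
    """Return (error flag, captured output), appending a user message on failure."""
    imports, code = imports.strip(), code.strip()

    # Stage 1: run the imports on their own
    output, exit_code = run(imports + "\n")
    if exit_code.strip() != "0":
        messages.append({"role": "user", "content": f"Your solution failed the import test: {output}"})
        return "yes", output

    # Stage 2: run the imports together with the code body
    output, exit_code = run(imports + "\n" + code + "\n")
    if exit_code.strip() != "0":
        messages.append({"role": "user", "content": f"Your solution failed the code execution test: {output}"})
        return "yes", output

    return "no", output


def fake_run(source: str) -> tuple[str, str]:
    """Hypothetical runner: any script mentioning `missing_module` fails."""
    if "missing_module" in source:
        return "ModuleNotFoundError: No module named 'missing_module'", "1"
    return "ok\n", "0"


history: list[dict[str, str]] = []
bad_import = check_solution("import missing_module", "print(1)", fake_run, history)
bad_body = check_solution("import os", "missing_module.run()", fake_run, history)
passing = check_solution("import os", "print(1)", fake_run, history)
print(bad_import[0], bad_body[0], passing[0])  # yes yes no
```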

```
@env.task
async def code_check(state: AgentState) -> AgentState:
    """
    Check code

    Args:
        state (AgentState): The current agent state

    Returns:
        AgentState: Updated state with `error` set to "yes" or "no" and `output` set to the captured program output
    """

    print("---CHECKING CODE---")

    # State
    messages = state.messages
    code_solution = state.generation
    iterations = state.iterations

    # Get solution components
    imports = code_solution.imports.strip()
    code = code_solution.code.strip()

    # Create temp file for imports
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".py", delete=False
    ) as imports_file:
        imports_file.write(imports + "\n")
        imports_path = imports_file.name

    # Create temp file for code body
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as code_file:
        code_file.write(imports + "\n" + code + "\n")
        code_path = code_file.name

    # Check imports
    import_output, import_exit_code = await code_runner_task(
        script=await File.from_local(imports_path)
    )

    if import_exit_code.strip() != "0":
        print("---CODE IMPORT CHECK: FAILED---")
        error_message = [
            {
                "role": "user",
                "content": f"Your solution failed the import test: {import_output}",
            }
        ]
        messages += error_message
        return AgentState(
            generation=code_solution,
            messages=messages,
            iterations=iterations,
            error="yes",
            output=import_output,
        )
    else:
        print("---CODE IMPORT CHECK: PASSED---")

    # Check execution
    code_output, code_exit_code = await code_runner_task(
        script=await File.from_local(code_path)
    )

    if code_exit_code.strip() != "0":
        print("---CODE BLOCK CHECK: FAILED---")
        error_message = [
            {
                "role": "user",
                "content": f"Your solution failed the code execution test: {code_output}",
            }
        ]
        messages += error_message
        return AgentState(
            generation=code_solution,
            messages=messages,
            iterations=iterations,
            error="yes",
            output=code_output,
        )
    else:
        print("---CODE BLOCK CHECK: PASSED---")

    # No errors
    print("---NO CODE TEST FAILURES---")

    return AgentState(
        generation=code_solution,
        messages=messages,
        iterations=iterations,
        error="no",
        output=code_output,
    )

```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/code_runner/agent.py*
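
This is a minimal local sketch of the two-stage check that runs without a cluster. It makes two assumptions. First, a hypothetical `run_script` helper replaces `code_runner_task`: it runs the script with the current Python interpreter through `subprocess` instead of `uv run --script` in a container. Second, the helper returns the combined output and exit status as strings, much as the container writes them to `/var/outputs/result` and `/var/outputs/exit_code`. `two_stage_check` is also an illustrative name, not part of the tutorial.

```python
import os
import subprocess
import sys
import tempfile


def run_script(source: str) -> tuple[str, str]:
    """Hypothetical local stand-in for code_runner_task.

    Runs `source` with the current Python interpreter and returns
    (output, exit_code) as strings, like the files the container task writes.
    """
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
        f.write(source)
        path = f.name
    try:
        proc = subprocess.run(
            [sys.executable, path],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # merge stderr into the output, like `2>&1`
            text=True,
        )
    finally:
        os.remove(path)  # clean up the temporary script
    # `echo $?` writes the status plus a newline, so callers strip() before comparing
    return proc.stdout, f"{proc.returncode}\n"


def two_stage_check(imports: str, code: str) -> tuple[str, str]:
    """Return (error flag, output): "yes" on failure, "no" if both stages pass."""
    # Stage 1: run only the imports, so missing packages are reported on their own
    output, exit_code = run_script(imports.strip() + "\n")
    if exit_code.strip() != "0":
        return "yes", output
    # Stage 2: run the imports together with the code body
    output, exit_code = run_script(imports.strip() + "\n" + code.strip() + "\n")
    if exit_code.strip() != "0":
        return "yes", output
    return "no", output


if __name__ == "__main__":
    print(two_stage_check("import json", "print(json.dumps({'ok': True}))"))
```

Checking the imports first lets the agent report an import failure (for example, a misspelled module) to the model separately from errors raised by the code body.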

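If either check fails, a separate `reflect` task asks the model to reflect on the error. The model receives the documentation and the full message history as input.
The reflection is appended to the agent state's messages to guide the next generation attempt.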

```
@env.task
async def reflect(
    state: AgentState, concatenated_content: str, debug: bool
) -> AgentState:
    """
    Reflect on errors

    Args:
        state (dict): The current graph state
        concatenated_content (str): Concatenated docs content
        debug (bool): Debug mode

    Returns:
        state (dict): New key added to state, reflection
    """

    print("---REFLECTING---")

    # State
    messages = state.messages
    iterations = state.iterations
    code_solution = state.generation

    # Prompt reflection
    code_gen_chain = await generate_code_gen_chain(debug)

    # Add reflection
    reflections = code_gen_chain.invoke(
        {"context": concatenated_content, "messages": messages}
    )

    messages += [
        {
            "role": "assistant",
            "content": f"Here are reflections on the error: {reflections}",
        }
    ]

    return AgentState(
        generation=code_solution,
        messages=messages,
        iterations=iterations,
        error=state.error,
        output=state.output,
    )

```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/code_runner/agent.py*
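
Without Flyte and LangChain, reflection is a state transition with three parts: ask the model about the failure, append its answer to the history, and carry everything else forward unchanged. The sketch below simplifies the tutorial's version. `State` is a hypothetical, trimmed-down stand-in for `AgentState`. `llm` is any callable that maps a message list to text, standing in for the LangChain chain.

```python
from dataclasses import dataclass, replace
from typing import Callable, Optional

Message = dict[str, str]


@dataclass(frozen=True)
class State:
    """Simplified, hypothetical stand-in for AgentState."""

    messages: tuple[Message, ...] = ()
    iterations: int = 0
    error: str = "no"
    output: Optional[str] = None


def reflect(state: State, llm: Callable[[list[Message]], str]) -> State:
    """Append the model's reflection on the latest failure to the history."""
    # The model sees the whole conversation, including the error feedback
    reflection = llm(list(state.messages))
    note = {"role": "assistant", "content": f"Here are reflections on the error: {reflection}"}
    # Build a new state; iterations, error, and output carry over unchanged
    return replace(state, messages=state.messages + (note,))


if __name__ == "__main__":
    failed = State(
        messages=({"role": "user", "content": "Your solution failed the import test: ..."},),
        iterations=1,
        error="yes",
    )
    updated = reflect(failed, lambda msgs: f"saw {len(msgs)} message(s)")
    print(updated.messages[-1]["content"])
```

The function returns a fresh state instead of extending the input list in place. This keeps it free of side effects and makes it easy to test in isolation.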

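Finally, we define a `main` task that retrieves the documentation and orchestrates the steps above. Each generate-and-check pass runs inside its own `flyte.group`.
If a check fails, the agent reflects on the error and retries until the code passes or `max_iterations` is reached. In either case, it returns the last solution along with its execution output.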

```
@env.task
async def main(
    question: str = (
        "Define a two-task pattern where the second catches OOM from the first and retries with more memory."
    ),
    url: str = "https://pre-release-v2.docs-builder.pages.dev/docs/byoc/user-guide/",
    max_iterations: int = 3,
    debug: bool = False,
) -> str:
    concatenated_content = await docs_retriever(url=url)

    state: AgentState = AgentState()
    iterations = 0

    while True:
        with flyte.group(f"code-generation-pass-{iterations + 1}"):
            state = await generate(question, state, concatenated_content, debug)
            state = await code_check(state)

            error = state.error
            iterations = state.iterations

            if error == "no" or iterations >= max_iterations:
                print("---DECISION: FINISH---")
                code_solution = state.generation

                prefix = code_solution.prefix
                imports = code_solution.imports
                code = code_solution.code

                code_output = state.output

                return f"""{prefix}

{imports}
{code}

Result of code execution:
{code_output}
"""
            else:
                print("---DECISION: RE-TRY SOLUTION---")
                state = await reflect(state, concatenated_content, debug)

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
    run.wait()

```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/code_runner/agent.py*
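
The control flow of `main` reduces to a bounded retry loop. In this minimal sketch, `generate`, `check`, and `reflect` are hypothetical callables that stand in for the Flyte tasks. The stubs simulate a model that produces working code on its second attempt. As in the tutorial, generation increments the iteration count, and the loop returns the last attempt whether it passed or ran out of iterations.

```python
from dataclasses import dataclass, replace
from typing import Callable


@dataclass(frozen=True)
class State:
    """Hypothetical, minimal agent state for illustrating the loop."""

    attempt: str = ""
    iterations: int = 0
    error: str = "no"


def run_agent(
    generate: Callable[[State], State],
    check: Callable[[State], State],
    reflect: Callable[[State], State],
    max_iterations: int = 3,
) -> State:
    state = State()
    while True:
        state = check(generate(state))
        # Stop on success, or when the iteration budget is spent
        if state.error == "no" or state.iterations >= max_iterations:
            return state
        state = reflect(state)


# Hypothetical stubs: the "model" produces working code on its second try.
def fake_generate(s: State) -> State:
    n = s.iterations + 1
    return replace(s, attempt="good" if n >= 2 else "bad", iterations=n)


def fake_check(s: State) -> State:
    return replace(s, error="no" if s.attempt == "good" else "yes")


def no_op_reflect(s: State) -> State:
    return s


if __name__ == "__main__":
    print(run_agent(fake_generate, fake_check, no_op_reflect))
```

The exit condition is checked before reflecting. As a result, a run that exhausts `max_iterations` does not make a final, unused reflection call.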

## Running the code agent

If things are working properly, you should see output similar to the following:

```
---GENERATING CODE SOLUTION---
---CHECKING CODE---
---CODE BLOCK CHECK: PASSED---
---NO CODE TEST FAILURES---
---DECISION: FINISH---
In this solution, we define two tasks using Flyte v2.
The first task, `oomer`, is designed to simulate an out-of-memory (OOM) error by attempting to allocate a large list.
The second task, `failure_recovery`, attempts to execute `oomer` and catches any OOM errors.
If an OOM error is caught, it retries the `oomer` task with increased memory resources.
This pattern demonstrates how to handle resource-related exceptions and dynamically adjust task configurations in Flyte workflows.

import asyncio
import flyte
import flyte.errors
env = flyte.TaskEnvironment(name="oom_example", resources=flyte.Resources(cpu=1, memory="250Mi"))

@env.task
async def oomer(x: int):
    large_list = [0] * 100000000  # Simulate OOM
    print(len(large_list))

@env.task
async def always_succeeds() -> int:
    await asyncio.sleep(1)
    return 42

...
```

You can run the code agent on a Flyte/Union cluster using the following command:

```
uv run agent.py
```

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/agents/competitive-intelligence-agent ===

# Competitive intelligence agent

> [!NOTE]
> Code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/competitive_intelligence_agent).

This example demonstrates how to build a continuous competitive and market intelligence agent on Flyte. The agent fans out across a list of competitors, pulls fresh, source-cited web and news results from the [You.com Search API](https://you.com/docs/search/overview), and uses [Claude](https://docs.anthropic.com/) via [LiteLLM](https://docs.litellm.ai/) to extract structured **deltas** — pricing changes, product launches, funding events, leadership moves, and more — into a knowledge-graph-ready table.

You.com returns ranked web and news results with snippets and publication timestamps, giving the LLM attributable sources to cite. Flyte orchestrates the rest:

- **Fan-out parallelism** across competitors with `asyncio.gather`
- **`cache="auto"`** so that parallel or repeated runs reuse prior You.com and LLM results when their queries overlap
- **`@flyte.trace`** on every You.com and LLM call for full prompt → query → source lineage
- **Flyte reports** that render an HTML dashboard grouping deltas by competitor and category
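
The fan-out in the first bullet follows the standard `asyncio.gather` pattern: create one coroutine per competitor, await them together, and get the results back in input order. In this sketch, `watch_competitor` is a hypothetical coroutine that sleeps briefly in place of the You.com and Claude calls. `CompetitorWatch` is simplified to hold plain-string deltas rather than the tutorial's `Delta` objects.

```python
import asyncio
from dataclasses import dataclass, field


@dataclass
class CompetitorWatch:
    """Simplified stand-in: deltas are plain strings here."""

    competitor: str
    deltas: list[str] = field(default_factory=list)


async def watch_competitor(name: str) -> CompetitorWatch:
    # Hypothetical stand-in for "search You.com, then ask the LLM for deltas"
    await asyncio.sleep(0.1)
    return CompetitorWatch(competitor=name, deltas=[f"{name}: no change detected"])


async def fan_out(competitors: list[str]) -> list[CompetitorWatch]:
    # The coroutines run concurrently; gather returns results in input order
    return await asyncio.gather(*(watch_competitor(c) for c in competitors))


if __name__ == "__main__":
    for watch in asyncio.run(fan_out(["Acme", "Globex", "Initech"])):
        print(watch.competitor, watch.deltas)
```

The simulated delays overlap, so the three watches finish in roughly 0.1 seconds total rather than 0.3.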

![Competitive intelligence agent report](https://www.union.ai/docs/v2/union/_static/images/tutorials/competitive_intelligence_agent/competitive-intelligence-agent.png)

## Setting up the environment

The agent runs in a single `TaskEnvironment` with secrets for the You.com and Anthropic API keys, automatic caching, and a container image built from the `uv` script dependencies.

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#     "flyte>=2.4.0",
#     "httpx>=0.27.0",
#     "litellm>=1.72.0",
# ]
# main = "competitive_intelligence"
# params = ""
# ///
"""Continuous competitive & market intelligence agent.

A Dragonfly-style agent that fans out across competitors, pulls fresh,
source-cited web + news results from the You.com Search API, and uses Claude to
extract structured "deltas" (pricing, features, funding, leadership, etc.) into
a knowledge-graph-ready table.
"""

# {{docs-fragment env}}
import asyncio
import json
from dataclasses import dataclass, field

import flyte

MODEL = "anthropic/claude-haiku-4-5"

env = flyte.TaskEnvironment(
    name="competitive-intelligence",
    secrets=[
        flyte.Secret(key="youdotcom-api-key", as_env_var="YOU_API_KEY"),
        flyte.Secret(key="internal-anthropic-api-key", as_env_var="ANTHROPIC_API_KEY"),
    ],
    image=flyte.Image.from_uv_script(__file__, name="competitive-intelligence", pre=True),
    resources=flyte.Resources(cpu="1", memory="1Gi"),
    cache="auto",
)
# {{/docs-fragment env}}

# {{docs-fragment data_types}}
@dataclass
class SearchHit:
    """A You.com Search result with its full structured metadata."""

    title: str
    url: str
    domain: str
    snippet: str
    published: str  # You.com page_age timestamp
    author: str
    favicon: str  # You.com favicon_url
    thumbnail: str
    section: str  # "news" or "web" — You.com's auto classification

@dataclass
class Delta:
    competitor: str
    category: str
    summary: str
    confidence: float
    source: SearchHit | None = None

@dataclass
class CompetitorWatch:
    competitor: str
    deltas: list[Delta] = field(default_factory=list)
    sources: list[SearchHit] = field(default_factory=list)

@dataclass
class IntelReport:
    watches: list[CompetitorWatch] = field(default_factory=list)

    @property
    def deltas(self) -> list[Delta]:
        return [d for w in self.watches for d in w.deltas]
# {{/docs-fragment data_types}}

# {{docs-fragment you_search}}
YOU_SEARCH_URL = "https://ydc-index.io/v1/search"

async def _you_get(url: str, params: dict, timeout: float = 60.0) -> dict:
    """GET with exponential backoff + jitter on 429 rate limits."""
    import asyncio
    import os
    import random

    import httpx

    headers = {"X-API-Key": os.environ["YOU_API_KEY"]}
    async with httpx.AsyncClient(timeout=timeout) as client:
        for attempt in range(7):
            resp = await client.get(url, headers=headers, params=params)
            if resp.status_code == 429 and attempt < 6:
                wait = float(resp.headers.get("retry-after") or 0) or min(2**attempt, 30)
                await asyncio.sleep(wait + random.uniform(0, 2))
                continue
            resp.raise_for_status()
            return resp.json()
    resp.raise_for_status()
    return resp.json()

def _domain(url: str) -> str:
    from urllib.parse import urlparse

    try:
        return urlparse(url).netloc.replace("www.", "")
    except Exception:
        return ""

def _favicon(item: dict, url: str) -> str:
    return item.get("favicon_url") or (
        f"https://ydc-index.io/favicon?domain={_domain(url)}&size=128"
    )

@flyte.trace
async def you_search(query: str, count: int = 8, freshness: str = "week") -> list[SearchHit]:
    """Call the You.com Search API and return unified web + news hits."""
    params = {"query": query, "count": count, "freshness": freshness}
    data = await _you_get(YOU_SEARCH_URL, params)

    results = data.get("results", {})
    hits: list[SearchHit] = []
    for section in ("news", "web"):
        for item in results.get(section, []) or []:
            snippets = item.get("snippets") or []
            url = item.get("url", "")
            hits.append(
                SearchHit(
                    title=item.get("title", ""),
                    url=url,
                    domain=_domain(url),
                    snippet=(snippets[0] if snippets else item.get("description", "")),
                    published=item.get("page_age", "") or "",
                    author=", ".join(item.get("authors") or []),
                    favicon=_favicon(item, url),
                    thumbnail=item.get("thumbnail_url", "") or "",
                    section=section,
                )
            )
    return hits
# {{/docs-fragment you_search}}

# {{docs-fragment llm}}
@flyte.trace
async def llm_json(system: str, user: str) -> dict | list:
    """Call Claude via LiteLLM and parse a JSON response."""
    from litellm import acompletion

    resp = await acompletion(
        model=MODEL,
        messages=[
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ],
        temperature=0.0,
        max_tokens=2048,
    )
    content = resp.choices[0].message.content
    return _parse_json(content)

def _parse_json(text: str) -> dict | list:
    text = text.strip()
    if text.startswith("```"):
        text = text.split("```", 2)[1]
        if text.lstrip().startswith("json"):
            text = text.lstrip()[4:]
    start = min(
        (i for i in (text.find("{"), text.find("[")) if i != -1),
        default=0,
    )
    end = max(text.rfind("}"), text.rfind("]")) + 1
    return json.loads(text[start:end])
# {{/docs-fragment llm}}

EXTRACT_SYSTEM = """You are a competitive-intelligence analyst. Given fresh \
search results about a competitor, extract concrete, recently-changed signals \
("deltas") in the requested categories. Only report changes that are supported \
by a specific search result. Respond with a JSON object of the form:
{"deltas": [{"category": str, "summary": str, "source_index": int (the [n] of \
the supporting search result), "confidence": float between 0 and 1}]}
If there are no clear changes, return {"deltas": []}."""

# {{docs-fragment watch_competitor}}
@env.task(retries=3)
async def watch_competitor(
    competitor: str,
    categories: list[str],
    freshness: str,
) -> CompetitorWatch:
    """Search for fresh signals on one competitor and extract structured deltas."""
    query = (
        f"{competitor} "
        + " OR ".join(categories)
        + " announcement OR news OR update"
    )
    hits = await you_search(query, count=8, freshness=freshness)
    if not hits:
        return CompetitorWatch(competitor=competitor)

    evidence = "\n\n".join(
        f"[{i + 1}] {h.title} ({h.published}) — {h.domain}\n{h.url}\n{h.snippet}"
        for i, h in enumerate(hits)
    )
    user = (
        f"Competitor: {competitor}\n"
        f"Categories to watch: {', '.join(categories)}\n\n"
        f"Search results:\n{evidence}"
    )
    parsed = await llm_json(EXTRACT_SYSTEM, user)
    raw_deltas = parsed.get("deltas", []) if isinstance(parsed, dict) else []

    deltas: list[Delta] = []
    cited: list[SearchHit] = []
    for d in raw_deltas:
        idx = int(d.get("source_index", 0) or 0)
        src = hits[idx - 1] if 1 <= idx <= len(hits) else None
        if src is not None and src not in cited:
            cited.append(src)
        deltas.append(
            Delta(
                competitor=competitor,
                category=str(d.get("category", "unknown")),
                summary=str(d.get("summary", "")),
                confidence=float(d.get("confidence", 0.0) or 0.0),
                source=src,
            )
        )
    return CompetitorWatch(competitor=competitor, deltas=deltas, sources=cited)
# {{/docs-fragment watch_competitor}}

# {{docs-fragment report}}
REPORT_CSS = """
<style>
  .rpt { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto,
         Helvetica, Arial, sans-serif; color:#1f2933; max-width:1040px;
         margin:0 auto; }
  .rpt h1 { font-size:22px; margin:0 0 4px; color:#102a43; }
  .rpt .sub { color:#647488; font-size:13px; margin:0 0 18px; }
  .rpt .stats { display:flex; gap:10px; flex-wrap:wrap; margin:0 0 22px; }
  .rpt .pill { background:#f0f4f8; border-radius:999px; padding:6px 14px;
               font-size:13px; color:#334e68; }
  .rpt .pill b { color:#102a43; }
  .rpt .card { border:1px solid #e4e7eb; border-radius:12px; padding:16px 18px;
               margin:0 0 14px; box-shadow:0 1px 3px rgba(16,42,67,0.06);
               background:#fff; }
  .rpt .card h2 { font-size:16px; margin:0 0 6px; color:#102a43; }
  .rpt .row { padding:11px 0; border-top:1px solid #f0f2f5; }
  .rpt .row:first-of-type { border-top:none; }
  .rpt .chip { display:inline-block; font-size:11px; font-weight:600;
               padding:3px 9px; border-radius:6px; white-space:nowrap;
               text-transform:uppercase; letter-spacing:.03em;
               background:#e0e8f9; color:#2b4ba0; margin-right:8px; }
  .rpt .summary { margin:6px 0 4px; font-size:14px; line-height:1.45; }
  .rpt .meta { color:#829ab1; font-size:12px; }
  .rpt a { color:#2b6cb0; text-decoration:none; }
  .rpt a:hover { text-decoration:underline; }
  .rpt .bar { display:inline-block; width:60px; height:6px; border-radius:3px;
              background:#e4e7eb; vertical-align:middle; overflow:hidden;
              margin-right:6px; }
  .rpt .bar > span { display:block; height:100%; background:#3ebd93; }
  .rpt .empty { color:#829ab1; font-style:italic; padding:8px 0; }
  .rpt .cite { display:flex; gap:9px; align-items:flex-start; background:#f7f9fb;
               border:1px solid #eef1f4; border-radius:8px; padding:8px 10px;
               margin-top:8px; }
  .rpt .cite img.fav { width:16px; height:16px; border-radius:3px; margin-top:2px;
                       flex:0 0 auto; background:#e4e7eb; }
  .rpt .cite .cb { font-size:12px; line-height:1.45; }
  .rpt .cite .cdom { font-weight:600; color:#334e68; }
  .rpt .cite .ctag { font-size:10px; font-weight:700; text-transform:uppercase;
                     color:#fff; background:#bcccdc; border-radius:4px;
                     padding:1px 5px; margin-left:6px; }
  .rpt .cite .ctag.news { background:#e8833a; }
  .rpt .cite .cmeta { color:#829ab1; }
  .rpt .cite .csnip { color:#52606d; font-style:italic; margin-top:3px; }
  .rpt .src-head { font-size:11px; text-transform:uppercase; letter-spacing:.04em;
                   color:#627d98; margin:14px 0 4px; }
  .rpt .yoube { font-size:11px; color:#9aa5b1; margin-top:4px; }
</style>
"""

def _conf_bar(conf: float) -> str:
    pct = max(0, min(100, int(conf * 100)))
    return (
        f"<span class='bar'><span style='width:{pct}%'></span></span>"
        f"<span class='meta'>{conf:.0%} confidence</span>"
    )

def _cite(src: SearchHit | None) -> str:
    """Render a rich You.com citation: favicon, domain, date, author, snippet."""
    if src is None:
        return ""
    tag = (
        f"<span class='ctag news'>news</span>"
        if src.section == "news"
        else "<span class='ctag'>web</span>"
    )
    meta_bits = []
    if src.published:
        meta_bits.append(src.published[:10])
    if src.author:
        meta_bits.append(f"by {src.author}")
    meta = " &middot; ".join(meta_bits)
    snip = f"<div class='csnip'>&ldquo;{src.snippet}&rdquo;</div>" if src.snippet else ""
    return (
        f"<div class='cite'>"
        f"<img class='fav' src='{src.favicon}' alt=''/>"
        f"<div class='cb'>"
        f"<a href='{src.url}'><span class='cdom'>{src.domain or 'source'}</span></a>{tag}"
        f"<div class='cmeta'>{meta}</div>{snip}</div></div>"
    )

def _render_report(report: IntelReport) -> str:
    watches = sorted(report.watches, key=lambda w: w.competitor)
    total_sources = sum(len(w.sources) for w in watches)

    cards = []
    for w in watches:
        deltas = sorted(w.deltas, key=lambda d: -d.confidence)
        rows = "".join(
            f"<div class='row'><span class='chip'>{d.category}</span>"
            f"<div class='summary'>{d.summary}</div>"
            f"{_conf_bar(d.confidence)}"
            f"{_cite(d.source)}"
            "</div>"
            for d in deltas
        )
        cards.append(
            f"<div class='card'><h2>{w.competitor}</h2>"
            f"<span class='meta'>{len(deltas)} signal(s) &middot; "
            f"{len(w.sources)} You.com source(s)</span>{rows or ''}</div>"
        )

    return f"""
    {REPORT_CSS}
    <div class="rpt">
      <h1>Competitive Intelligence Deltas</h1>
      <p class="sub">Fresh, source-cited market signals — every delta links back
      to a ranked, timestamped You.com Search result.</p>
      <div class="stats">
        <span class="pill"><b>{len(report.deltas)}</b> signals</span>
        <span class="pill"><b>{len(watches)}</b> competitors tracked</span>
        <span class="pill"><b>{total_sources}</b> cited You.com sources</span>
      </div>
      {''.join(cards) or "<p class='empty'>No signals detected in this window.</p>"}
      <p class="yoube">Sources retrieved and ranked by the You.com Search API
      (web + auto-classified news), with publication timestamps, authors, and
      snippet provenance preserved for full prompt &rarr; citation lineage.</p>
    </div>
    """
# {{/docs-fragment report}}

# {{docs-fragment driver}}
@env.task(report=True)
async def competitive_intelligence(
    competitors: list[str] = [
        "Anthropic",
        "OpenAI",
        "Mistral AI",
        "Google DeepMind",
        "Cohere",
        "Perplexity AI",
        "xAI",
        "Hugging Face",
        "Databricks",
        "Together AI",
    ],
    categories: list[str] = [
        "pricing",
        "product launch",
        "model release",
        "funding",
        "leadership",
        "partnership",
    ],
    freshness: str = "week",
) -> IntelReport:
    """Fan out across competitors and aggregate structured deltas."""
    with flyte.group("watch-competitors"):
        results = await asyncio.gather(
            *[watch_competitor(c, categories, freshness) for c in competitors]
        )

    report = IntelReport(watches=list(results))

    await flyte.report.replace.aio(_render_report(report), do_flush=True)
    await flyte.report.flush.aio()
    return report
# {{/docs-fragment driver}}

# {{docs-fragment main}}
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(competitive_intelligence)
    print(run.url)
    run.wait()
# {{/docs-fragment main}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/competitive_intelligence_agent/main.py*

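The Python dependencies are declared at the top of the file as [PEP 723](https://peps.python.org/pep-0723/) inline script metadata, the `uv` script format. `flyte.Image.from_uv_script(__file__, ...)` in the task environment uses this block to build the container image: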

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#     "flyte>=2.4.0",
#     "httpx>=0.27.0",
#     "litellm>=1.72.0",
# ]
# ///
```

## Data types

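The agent models its data as four dataclasses: `SearchHit` (one You.com result), `Delta` (one extracted change for a competitor), `CompetitorWatch` (the deltas and cited sources for one competitor), and `IntelReport` (all watches, with a `deltas` property that flattens them). A `Delta` links back to the `SearchHit` that supports it through its `source` field, which is `None` when the model's citation index does not match any result. Each `SearchHit` preserves the You.com metadata: domain, publication date, author, snippet, favicon, and whether the result came from the news or web section.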

```
from dataclasses import dataclass, field

@dataclass
class SearchHit:
    """A You.com Search result with its full structured metadata."""

    title: str
    url: str
    domain: str
    snippet: str
    published: str  # You.com page_age timestamp
    author: str
    favicon: str  # You.com favicon_url
    thumbnail: str
    section: str  # "news" or "web" — You.com's auto classification

@dataclass
class Delta:
    competitor: str
    category: str
    summary: str
    confidence: float
    source: SearchHit | None = None

@dataclass
class CompetitorWatch:
    competitor: str
    deltas: list[Delta] = field(default_factory=list)
    sources: list[SearchHit] = field(default_factory=list)

@dataclass
class IntelReport:
    watches: list[CompetitorWatch] = field(default_factory=list)

    @property
    def deltas(self) -> list[Delta]:
        return [d for w in self.watches for d in w.deltas]
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/competitive_intelligence_agent/main.py*

## Search with the You.com Search API

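The `you_search` helper calls the [You.com Search API](https://you.com/docs/search/overview) at `https://ydc-index.io/v1/search`, authenticating with an `X-API-Key` header read from the `YOU_API_KEY` environment variable. It requests unified web and news results with a `freshness` filter (`day`, `week`, `month`, or `year`), then flattens the `news` and `web` sections of the response into a single list of `SearchHit` objects, news first, so that the LLM can cite each hit by its index. When the API returns HTTP 429, `_you_get` retries up to six times, waiting for the `Retry-After` interval if the server sends one and otherwise backing off exponentially (1, 2, 4, ... seconds, capped at 30), plus up to 2 seconds of random jitter.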
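The response handling in `you_search` can be exercised offline. The sketch below reproduces its flattening loop against a hand-written payload. `Hit` is a reduced, hypothetical stand-in for `SearchHit`, and the payload only mimics the fields `you_search` reads (`results.news`, `results.web`, `url`, `title`, `snippets`, `description`); treat the You.com API reference as the authority on the real response schema.

```python
from dataclasses import dataclass
from urllib.parse import urlparse


@dataclass
class Hit:
    # Reduced stand-in for the tutorial's SearchHit (a few fields only).
    title: str
    domain: str
    snippet: str
    section: str


def flatten_results(data: dict) -> list[Hit]:
    """Mirror you_search's loop: news results first, then web results."""
    results = data.get("results", {})
    hits: list[Hit] = []
    for section in ("news", "web"):
        for item in results.get(section, []) or []:
            snippets = item.get("snippets") or []
            url = item.get("url", "")
            hits.append(
                Hit(
                    title=item.get("title", ""),
                    domain=urlparse(url).netloc.replace("www.", ""),
                    # Prefer the first snippet; fall back to the description.
                    snippet=snippets[0] if snippets else item.get("description", ""),
                    section=section,
                )
            )
    return hits


# Hypothetical payload shaped the way you_search reads it.
sample = {
    "results": {
        "web": [
            {"title": "Pricing page", "url": "https://www.example.com/pricing",
             "description": "Plans and pricing.", "snippets": []},
        ],
        "news": [
            {"title": "Example Co. raises Series B", "url": "https://news.example.org/a",
             "snippets": ["Example Co. announced new funding."]},
        ],
    }
}

hits = flatten_results(sample)
# Number the hits the way watch_competitor does, so the model can cite [n].
for i, h in enumerate(hits, start=1):
    print(f"[{i}] {h.title} ({h.section}) - {h.domain}")
```

Because news hits come first, a citation such as `[1]` refers to the first news result whenever the API returns any.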

See the [Search API reference](https://you.com/docs/api-reference/search/v1-search) for all supported parameters, including `count`, `country`, and search operators.

```
YOU_SEARCH_URL = "https://ydc-index.io/v1/search"

async def _you_get(url: str, params: dict, timeout: float = 60.0) -> dict:
    """GET with exponential backoff + jitter on 429 rate limits."""
    import asyncio
    import os
    import random

    import httpx

    headers = {"X-API-Key": os.environ["YOU_API_KEY"]}
    async with httpx.AsyncClient(timeout=timeout) as client:
        for attempt in range(7):
            resp = await client.get(url, headers=headers, params=params)
            if resp.status_code == 429 and attempt < 6:
                wait = float(resp.headers.get("retry-after") or 0) or min(2**attempt, 30)
                await asyncio.sleep(wait + random.uniform(0, 2))
                continue
            resp.raise_for_status()
            return resp.json()
    resp.raise_for_status()
    return resp.json()

def _domain(url: str) -> str:
    from urllib.parse import urlparse

    try:
        return urlparse(url).netloc.replace("www.", "")
    except Exception:
        return ""

def _favicon(item: dict, url: str) -> str:
    return item.get("favicon_url") or (
        f"https://ydc-index.io/favicon?domain={_domain(url)}&size=128"
    )

@flyte.trace
async def you_search(query: str, count: int = 8, freshness: str = "week") -> list[SearchHit]:
    """Call the You.com Search API and return unified web + news hits."""
    params = {"query": query, "count": count, "freshness": freshness}
    data = await _you_get(YOU_SEARCH_URL, params)

    results = data.get("results", {})
    hits: list[SearchHit] = []
    for section in ("news", "web"):
        for item in results.get(section, []) or []:
            snippets = item.get("snippets") or []
            url = item.get("url", "")
            hits.append(
                SearchHit(
                    title=item.get("title", ""),
                    url=url,
                    domain=_domain(url),
                    snippet=(snippets[0] if snippets else item.get("description", "")),
                    published=item.get("page_age", "") or "",
                    author=", ".join(item.get("authors") or []),
                    favicon=_favicon(item, url),
                    thumbnail=item.get("thumbnail_url", "") or "",
                    section=section,
                )
            )
    return hits
# {{/docs-fragment you_search}}

# {{docs-fragment llm}}
@flyte.trace
async def llm_json(system: str, user: str) -> dict | list:
    """Call Claude via LiteLLM and parse a JSON response."""
    from litellm import acompletion

    resp = await acompletion(
        model=MODEL,
        messages=[
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ],
        temperature=0.0,
        max_tokens=2048,
    )
    content = resp.choices[0].message.content
    return _parse_json(content)

def _parse_json(text: str) -> dict | list:
    text = text.strip()
    if text.startswith("```"):
        text = text.split("```", 2)[1]
        if text.lstrip().startswith("json"):
            text = text.lstrip()[4:]
    start = min(
        (i for i in (text.find("{"), text.find("[")) if i != -1),
        default=0,
    )
    end = max(text.rfind("}"), text.rfind("]")) + 1
    return json.loads(text[start:end])
# {{/docs-fragment llm}}

EXTRACT_SYSTEM = """You are a competitive-intelligence analyst. Given fresh \
search results about a competitor, extract concrete, recently-changed signals \
("deltas") in the requested categories. Only report changes that are supported \
by a specific search result. Respond with a JSON object of the form:
{"deltas": [{"category": str, "summary": str, "source_index": int (the [n] of \
the supporting search result), "confidence": float between 0 and 1}]}
If there are no clear changes, return {"deltas": []}."""

# {{docs-fragment watch_competitor}}
@env.task(retries=3)
async def watch_competitor(
    competitor: str,
    categories: list[str],
    freshness: str,
) -> CompetitorWatch:
    """Search for fresh signals on one competitor and extract structured deltas."""
    query = (
        f"{competitor} "
        + " OR ".join(categories)
        + " announcement OR news OR update"
    )
    hits = await you_search(query, count=8, freshness=freshness)
    if not hits:
        return CompetitorWatch(competitor=competitor)

    evidence = "\n\n".join(
        f"[{i + 1}] {h.title} ({h.published}) — {h.domain}\n{h.url}\n{h.snippet}"
        for i, h in enumerate(hits)
    )
    user = (
        f"Competitor: {competitor}\n"
        f"Categories to watch: {', '.join(categories)}\n\n"
        f"Search results:\n{evidence}"
    )
    parsed = await llm_json(EXTRACT_SYSTEM, user)
    raw_deltas = parsed.get("deltas", []) if isinstance(parsed, dict) else []

    deltas: list[Delta] = []
    cited: list[SearchHit] = []
    for d in raw_deltas:
        idx = int(d.get("source_index", 0) or 0)
        src = hits[idx - 1] if 1 <= idx <= len(hits) else None
        if src is not None and src not in cited:
            cited.append(src)
        deltas.append(
            Delta(
                competitor=competitor,
                category=str(d.get("category", "unknown")),
                summary=str(d.get("summary", "")),
                confidence=float(d.get("confidence", 0.0) or 0.0),
                source=src,
            )
        )
    return CompetitorWatch(competitor=competitor, deltas=deltas, sources=cited)
# {{/docs-fragment watch_competitor}}

# {{docs-fragment report}}
REPORT_CSS = """
<style>
  .rpt { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto,
         Helvetica, Arial, sans-serif; color:#1f2933; max-width:1040px;
         margin:0 auto; }
  .rpt h1 { font-size:22px; margin:0 0 4px; color:#102a43; }
  .rpt .sub { color:#647488; font-size:13px; margin:0 0 18px; }
  .rpt .stats { display:flex; gap:10px; flex-wrap:wrap; margin:0 0 22px; }
  .rpt .pill { background:#f0f4f8; border-radius:999px; padding:6px 14px;
               font-size:13px; color:#334e68; }
  .rpt .pill b { color:#102a43; }
  .rpt .card { border:1px solid #e4e7eb; border-radius:12px; padding:16px 18px;
               margin:0 0 14px; box-shadow:0 1px 3px rgba(16,42,67,0.06);
               background:#fff; }
  .rpt .card h2 { font-size:16px; margin:0 0 6px; color:#102a43; }
  .rpt .row { padding:11px 0; border-top:1px solid #f0f2f5; }
  .rpt .row:first-of-type { border-top:none; }
  .rpt .chip { display:inline-block; font-size:11px; font-weight:600;
               padding:3px 9px; border-radius:6px; white-space:nowrap;
               text-transform:uppercase; letter-spacing:.03em;
               background:#e0e8f9; color:#2b4ba0; margin-right:8px; }
  .rpt .summary { margin:6px 0 4px; font-size:14px; line-height:1.45; }
  .rpt .meta { color:#829ab1; font-size:12px; }
  .rpt a { color:#2b6cb0; text-decoration:none; }
  .rpt a:hover { text-decoration:underline; }
  .rpt .bar { display:inline-block; width:60px; height:6px; border-radius:3px;
              background:#e4e7eb; vertical-align:middle; overflow:hidden;
              margin-right:6px; }
  .rpt .bar > span { display:block; height:100%; background:#3ebd93; }
  .rpt .empty { color:#829ab1; font-style:italic; padding:8px 0; }
  .rpt .cite { display:flex; gap:9px; align-items:flex-start; background:#f7f9fb;
               border:1px solid #eef1f4; border-radius:8px; padding:8px 10px;
               margin-top:8px; }
  .rpt .cite img.fav { width:16px; height:16px; border-radius:3px; margin-top:2px;
                       flex:0 0 auto; background:#e4e7eb; }
  .rpt .cite .cb { font-size:12px; line-height:1.45; }
  .rpt .cite .cdom { font-weight:600; color:#334e68; }
  .rpt .cite .ctag { font-size:10px; font-weight:700; text-transform:uppercase;
                     color:#fff; background:#bcccdc; border-radius:4px;
                     padding:1px 5px; margin-left:6px; }
  .rpt .cite .ctag.news { background:#e8833a; }
  .rpt .cite .cmeta { color:#829ab1; }
  .rpt .cite .csnip { color:#52606d; font-style:italic; margin-top:3px; }
  .rpt .src-head { font-size:11px; text-transform:uppercase; letter-spacing:.04em;
                   color:#627d98; margin:14px 0 4px; }
  .rpt .yoube { font-size:11px; color:#9aa5b1; margin-top:4px; }
</style>
"""

def _conf_bar(conf: float) -> str:
    pct = max(0, min(100, int(conf * 100)))
    return (
        f"<span class='bar'><span style='width:{pct}%'></span></span>"
        f"<span class='meta'>{conf:.0%} confidence</span>"
    )

def _cite(src: SearchHit | None) -> str:
    """Render a rich You.com citation: favicon, domain, date, author, snippet."""
    if src is None:
        return ""
    tag = (
        f"<span class='ctag news'>news</span>"
        if src.section == "news"
        else "<span class='ctag'>web</span>"
    )
    meta_bits = []
    if src.published:
        meta_bits.append(src.published[:10])
    if src.author:
        meta_bits.append(f"by {src.author}")
    meta = " &middot; ".join(meta_bits)
    snip = f"<div class='csnip'>&ldquo;{src.snippet}&rdquo;</div>" if src.snippet else ""
    return (
        f"<div class='cite'>"
        f"<img class='fav' src='{src.favicon}' alt=''/>"
        f"<div class='cb'>"
        f"<a href='{src.url}'><span class='cdom'>{src.domain or 'source'}</span></a>{tag}"
        f"<div class='cmeta'>{meta}</div>{snip}</div></div>"
    )

def _render_report(report: IntelReport) -> str:
    watches = sorted(report.watches, key=lambda w: w.competitor)
    total_sources = sum(len(w.sources) for w in watches)

    cards = []
    for w in watches:
        deltas = sorted(w.deltas, key=lambda d: -d.confidence)
        rows = "".join(
            f"<div class='row'><span class='chip'>{d.category}</span>"
            f"<div class='summary'>{d.summary}</div>"
            f"{_conf_bar(d.confidence)}"
            f"{_cite(d.source)}"
            "</div>"
            for d in deltas
        )
        cards.append(
            f"<div class='card'><h2>{w.competitor}</h2>"
            f"<span class='meta'>{len(deltas)} signal(s) &middot; "
            f"{len(w.sources)} You.com source(s)</span>{rows or ''}</div>"
        )

    return f"""
    {REPORT_CSS}
    <div class="rpt">
      <h1>Competitive Intelligence Deltas</h1>
      <p class="sub">Fresh, source-cited market signals — every delta links back
      to a ranked, timestamped You.com Search result.</p>
      <div class="stats">
        <span class="pill"><b>{len(report.deltas)}</b> signals</span>
        <span class="pill"><b>{len(watches)}</b> competitors tracked</span>
        <span class="pill"><b>{total_sources}</b> cited You.com sources</span>
      </div>
      {''.join(cards) or "<p class='empty'>No signals detected in this window.</p>"}
      <p class="yoube">Sources retrieved and ranked by the You.com Search API
      (web + auto-classified news), with publication timestamps, authors, and
      snippet provenance preserved for full prompt &rarr; citation lineage.</p>
    </div>
    """
# {{/docs-fragment report}}

# {{docs-fragment driver}}
@env.task(report=True)
async def competitive_intelligence(
    competitors: list[str] = [
        "Anthropic",
        "OpenAI",
        "Mistral AI",
        "Google DeepMind",
        "Cohere",
        "Perplexity AI",
        "xAI",
        "Hugging Face",
        "Databricks",
        "Together AI",
    ],
    categories: list[str] = [
        "pricing",
        "product launch",
        "model release",
        "funding",
        "leadership",
        "partnership",
    ],
    freshness: str = "week",
) -> IntelReport:
    """Fan out across competitors and aggregate structured deltas."""
    with flyte.group("watch-competitors"):
        results = await asyncio.gather(
            *[watch_competitor(c, categories, freshness) for c in competitors]
        )

    report = IntelReport(watches=list(results))

    await flyte.report.replace.aio(_render_report(report), do_flush=True)
    await flyte.report.flush.aio()
    return report
# {{/docs-fragment driver}}

# {{docs-fragment main}}
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(competitive_intelligence)
    print(run.url)
    run.wait()
# {{/docs-fragment main}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/competitive_intelligence_agent/main.py*

> **📝 Note**
>
> We use `@flyte.trace` to track intermediate steps within a task, such as You.com API calls and LLM invocations. Each traced call appears as a span in the Flyte dashboard, with its inputs and outputs captured.

## Extract deltas with Claude

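A shared `llm_json` helper sends the system and user prompts to Claude (`anthropic/claude-haiku-4-5`) through LiteLLM's `acompletion` at temperature 0, then parses structured JSON out of the reply. Its `_parse_json` companion strips a surrounding Markdown code fence and any text before the first opening bracket or after the last closing bracket before calling `json.loads`.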

```
# {{docs-fragment llm}}
@flyte.trace
async def llm_json(system: str, user: str) -> dict | list:
    """Call Claude via LiteLLM and parse a JSON response."""
    from litellm import acompletion

    resp = await acompletion(
        model=MODEL,
        messages=[
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ],
        temperature=0.0,
        max_tokens=2048,
    )
    content = resp.choices[0].message.content
    return _parse_json(content)

def _parse_json(text: str) -> dict | list:
    text = text.strip()
    if text.startswith("```"):
        text = text.split("```", 2)[1]
        if text.lstrip().startswith("json"):
            text = text.lstrip()[4:]
    start = min(
        (i for i in (text.find("{"), text.find("[")) if i != -1),
        default=0,
    )
    end = max(text.rfind("}"), text.rfind("]")) + 1
    return json.loads(text[start:end])
# {{/docs-fragment llm}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/competitive_intelligence_agent/main.py*

## Watch one competitor

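The `watch_competitor` task builds a category-scoped search query, calls the You.com Search API, and asks Claude to extract only changes that are supported by a specific search result. If the search returns no hits, the task returns an empty `CompetitorWatch` without calling Claude. Otherwise the hits are numbered `[1]`, `[2]`, and so on in the prompt, and Claude cites each delta's evidence by that number (`source_index`). The task maps the index back to the matching `SearchHit`, so each delta carries a confidence score and a link to its source hit; a missing or out-of-range index leaves the delta without a source. The task is declared with `retries=3`, so Flyte can retry it up to three times if it fails.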

```
EXTRACT_SYSTEM = """You are a competitive-intelligence analyst. Given fresh \
search results about a competitor, extract concrete, recently-changed signals \
("deltas") in the requested categories. Only report changes that are supported \
by a specific search result. Respond with a JSON object of the form:
{"deltas": [{"category": str, "summary": str, "source_index": int (the [n] of \
the supporting search result), "confidence": float between 0 and 1}]}
If there are no clear changes, return {"deltas": []}."""

# {{docs-fragment watch_competitor}}
@env.task(retries=3)
async def watch_competitor(
    competitor: str,
    categories: list[str],
    freshness: str,
) -> CompetitorWatch:
    """Search for fresh signals on one competitor and extract structured deltas."""
    query = (
        f"{competitor} "
        + " OR ".join(categories)
        + " announcement OR news OR update"
    )
    hits = await you_search(query, count=8, freshness=freshness)
    if not hits:
        return CompetitorWatch(competitor=competitor)

    evidence = "\n\n".join(
        f"[{i + 1}] {h.title} ({h.published}) — {h.domain}\n{h.url}\n{h.snippet}"
        for i, h in enumerate(hits)
    )
    user = (
        f"Competitor: {competitor}\n"
        f"Categories to watch: {', '.join(categories)}\n\n"
        f"Search results:\n{evidence}"
    )
    parsed = await llm_json(EXTRACT_SYSTEM, user)
    raw_deltas = parsed.get("deltas", []) if isinstance(parsed, dict) else []

    deltas: list[Delta] = []
    cited: list[SearchHit] = []
    for d in raw_deltas:
        idx = int(d.get("source_index", 0) or 0)
        src = hits[idx - 1] if 1 <= idx <= len(hits) else None
        if src is not None and src not in cited:
            cited.append(src)
        deltas.append(
            Delta(
                competitor=competitor,
                category=str(d.get("category", "unknown")),
                summary=str(d.get("summary", "")),
                confidence=float(d.get("confidence", 0.0) or 0.0),
                source=src,
            )
        )
    return CompetitorWatch(competitor=competitor, deltas=deltas, sources=cited)
# {{/docs-fragment watch_competitor}}

# {{docs-fragment report}}
REPORT_CSS = """
<style>
  .rpt { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto,
         Helvetica, Arial, sans-serif; color:#1f2933; max-width:1040px;
         margin:0 auto; }
  .rpt h1 { font-size:22px; margin:0 0 4px; color:#102a43; }
  .rpt .sub { color:#647488; font-size:13px; margin:0 0 18px; }
  .rpt .stats { display:flex; gap:10px; flex-wrap:wrap; margin:0 0 22px; }
  .rpt .pill { background:#f0f4f8; border-radius:999px; padding:6px 14px;
               font-size:13px; color:#334e68; }
  .rpt .pill b { color:#102a43; }
  .rpt .card { border:1px solid #e4e7eb; border-radius:12px; padding:16px 18px;
               margin:0 0 14px; box-shadow:0 1px 3px rgba(16,42,67,0.06);
               background:#fff; }
  .rpt .card h2 { font-size:16px; margin:0 0 6px; color:#102a43; }
  .rpt .row { padding:11px 0; border-top:1px solid #f0f2f5; }
  .rpt .row:first-of-type { border-top:none; }
  .rpt .chip { display:inline-block; font-size:11px; font-weight:600;
               padding:3px 9px; border-radius:6px; white-space:nowrap;
               text-transform:uppercase; letter-spacing:.03em;
               background:#e0e8f9; color:#2b4ba0; margin-right:8px; }
  .rpt .summary { margin:6px 0 4px; font-size:14px; line-height:1.45; }
  .rpt .meta { color:#829ab1; font-size:12px; }
  .rpt a { color:#2b6cb0; text-decoration:none; }
  .rpt a:hover { text-decoration:underline; }
  .rpt .bar { display:inline-block; width:60px; height:6px; border-radius:3px;
              background:#e4e7eb; vertical-align:middle; overflow:hidden;
              margin-right:6px; }
  .rpt .bar > span { display:block; height:100%; background:#3ebd93; }
  .rpt .empty { color:#829ab1; font-style:italic; padding:8px 0; }
  .rpt .cite { display:flex; gap:9px; align-items:flex-start; background:#f7f9fb;
               border:1px solid #eef1f4; border-radius:8px; padding:8px 10px;
               margin-top:8px; }
  .rpt .cite img.fav { width:16px; height:16px; border-radius:3px; margin-top:2px;
                       flex:0 0 auto; background:#e4e7eb; }
  .rpt .cite .cb { font-size:12px; line-height:1.45; }
  .rpt .cite .cdom { font-weight:600; color:#334e68; }
  .rpt .cite .ctag { font-size:10px; font-weight:700; text-transform:uppercase;
                     color:#fff; background:#bcccdc; border-radius:4px;
                     padding:1px 5px; margin-left:6px; }
  .rpt .cite .ctag.news { background:#e8833a; }
  .rpt .cite .cmeta { color:#829ab1; }
  .rpt .cite .csnip { color:#52606d; font-style:italic; margin-top:3px; }
  .rpt .src-head { font-size:11px; text-transform:uppercase; letter-spacing:.04em;
                   color:#627d98; margin:14px 0 4px; }
  .rpt .yoube { font-size:11px; color:#9aa5b1; margin-top:4px; }
</style>
"""

def _conf_bar(conf: float) -> str:
    pct = max(0, min(100, int(conf * 100)))
    return (
        f"<span class='bar'><span style='width:{pct}%'></span></span>"
        f"<span class='meta'>{conf:.0%} confidence</span>"
    )

def _cite(src: SearchHit | None) -> str:
    """Render a rich You.com citation: favicon, domain, date, author, snippet."""
    if src is None:
        return ""
    tag = (
        f"<span class='ctag news'>news</span>"
        if src.section == "news"
        else "<span class='ctag'>web</span>"
    )
    meta_bits = []
    if src.published:
        meta_bits.append(src.published[:10])
    if src.author:
        meta_bits.append(f"by {src.author}")
    meta = " &middot; ".join(meta_bits)
    snip = f"<div class='csnip'>&ldquo;{src.snippet}&rdquo;</div>" if src.snippet else ""
    return (
        f"<div class='cite'>"
        f"<img class='fav' src='{src.favicon}' alt=''/>"
        f"<div class='cb'>"
        f"<a href='{src.url}'><span class='cdom'>{src.domain or 'source'}</span></a>{tag}"
        f"<div class='cmeta'>{meta}</div>{snip}</div></div>"
    )

def _render_report(report: IntelReport) -> str:
    watches = sorted(report.watches, key=lambda w: w.competitor)
    total_sources = sum(len(w.sources) for w in watches)

    cards = []
    for w in watches:
        deltas = sorted(w.deltas, key=lambda d: -d.confidence)
        rows = "".join(
            f"<div class='row'><span class='chip'>{d.category}</span>"
            f"<div class='summary'>{d.summary}</div>"
            f"{_conf_bar(d.confidence)}"
            f"{_cite(d.source)}"
            "</div>"
            for d in deltas
        )
        cards.append(
            f"<div class='card'><h2>{w.competitor}</h2>"
            f"<span class='meta'>{len(deltas)} signal(s) &middot; "
            f"{len(w.sources)} You.com source(s)</span>{rows or ''}</div>"
        )

    return f"""
    {REPORT_CSS}
    <div class="rpt">
      <h1>Competitive Intelligence Deltas</h1>
      <p class="sub">Fresh, source-cited market signals — every delta links back
      to a ranked, timestamped You.com Search result.</p>
      <div class="stats">
        <span class="pill"><b>{len(report.deltas)}</b> signals</span>
        <span class="pill"><b>{len(watches)}</b> competitors tracked</span>
        <span class="pill"><b>{total_sources}</b> cited You.com sources</span>
      </div>
      {''.join(cards) or "<p class='empty'>No signals detected in this window.</p>"}
      <p class="yoube">Sources retrieved and ranked by the You.com Search API
      (web + auto-classified news), with publication timestamps, authors, and
      snippet provenance preserved for full prompt &rarr; citation lineage.</p>
    </div>
    """
# {{/docs-fragment report}}

# {{docs-fragment driver}}
@env.task(report=True)
async def competitive_intelligence(
    competitors: list[str] = [
        "Anthropic",
        "OpenAI",
        "Mistral AI",
        "Google DeepMind",
        "Cohere",
        "Perplexity AI",
        "xAI",
        "Hugging Face",
        "Databricks",
        "Together AI",
    ],
    categories: list[str] = [
        "pricing",
        "product launch",
        "model release",
        "funding",
        "leadership",
        "partnership",
    ],
    freshness: str = "week",
) -> IntelReport:
    """Fan out across competitors and aggregate structured deltas."""
    with flyte.group("watch-competitors"):
        results = await asyncio.gather(
            *[watch_competitor(c, categories, freshness) for c in competitors]
        )

    report = IntelReport(watches=list(results))

    await flyte.report.replace.aio(_render_report(report), do_flush=True)
    await flyte.report.flush.aio()
    return report
# {{/docs-fragment driver}}

# {{docs-fragment main}}
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(competitive_intelligence)
    print(run.url)
    run.wait()
# {{/docs-fragment main}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/competitive_intelligence_agent/main.py*

## Orchestration

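The `competitive_intelligence` driver task launches one `watch_competitor` task per competitor inside a `flyte.group` named `watch-competitors`, awaits them concurrently with `asyncio.gather`, wraps the results in an `IntelReport`, and renders that report as HTML in the task's Flyte report.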

```
# {{docs-fragment driver}}
@env.task(report=True)
async def competitive_intelligence(
    competitors: list[str] = [
        "Anthropic",
        "OpenAI",
        "Mistral AI",
        "Google DeepMind",
        "Cohere",
        "Perplexity AI",
        "xAI",
        "Hugging Face",
        "Databricks",
        "Together AI",
    ],
    categories: list[str] = [
        "pricing",
        "product launch",
        "model release",
        "funding",
        "leadership",
        "partnership",
    ],
    freshness: str = "week",
) -> IntelReport:
    """Fan out across competitors and aggregate structured deltas."""
    with flyte.group("watch-competitors"):
        results = await asyncio.gather(
            *[watch_competitor(c, categories, freshness) for c in competitors]
        )

    report = IntelReport(watches=list(results))

    await flyte.report.replace.aio(_render_report(report), do_flush=True)
    await flyte.report.flush.aio()
    return report
# {{/docs-fragment driver}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/competitive_intelligence_agent/main.py*
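The fan-out/aggregate pattern in the driver can be tried without Flyte. The following is a minimal pure-`asyncio` sketch, assuming a hypothetical `fake_watch` coroutine in place of the `watch_competitor` task and plain dicts in place of the `CompetitorWatch` and `IntelReport` dataclasses. As in the driver, `asyncio.gather` runs the calls concurrently and returns their results in the same order as the inputs.

```
import asyncio


async def fake_watch(competitor: str, categories: list[str]) -> dict:
    """Hypothetical stand-in for watch_competitor: one fake delta per category."""
    await asyncio.sleep(0.01)  # simulate search + LLM latency
    return {
        "competitor": competitor,
        "deltas": [{"category": c, "confidence": 0.5} for c in categories],
    }


async def fan_out(competitors: list[str], categories: list[str]) -> list[dict]:
    # Launch one coroutine per competitor and wait for all of them.
    # gather() returns results in the order the awaitables were passed.
    return await asyncio.gather(*[fake_watch(c, categories) for c in competitors])


def all_deltas(watches: list[dict]) -> list[dict]:
    # Flatten per-competitor deltas, mirroring the IntelReport.deltas property.
    return [d for w in watches for d in w["deltas"]]


if __name__ == "__main__":
    watches = asyncio.run(fan_out(["Anthropic", "OpenAI"], ["pricing", "funding"]))
    print([w["competitor"] for w in watches])  # ['Anthropic', 'OpenAI']
    print(len(all_deltas(watches)))  # 4
```

In the real driver, each awaited `watch_competitor` call is a Flyte task, so the same `gather` call starts one task execution per competitor rather than local coroutines.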

## Run the agent

### Create secrets

Get a You.com API key from the [You.com platform](https://you.com/platform) (see the [quickstart guide](https://you.com/docs/quickstart)). Get an Anthropic API key from the [Anthropic console](https://console.anthropic.com/).

Register both keys as Flyte secrets. The secret key names must match those declared in the `TaskEnvironment`:

```
flyte create secret youdotcom-api-key <YOUR_YOU_API_KEY>
flyte create secret internal-anthropic-api-key <YOUR_ANTHROPIC_API_KEY>
```

See [Secrets](https://www.union.ai/docs/v2/union/user-guide/task-configuration/secrets/page.md) for scoping and file-based secrets.

### Run locally or remotely

From the root of the `unionai-examples` repository, change into the [example directory](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/competitive_intelligence_agent) and run the script:

```
cd v2/tutorials/competitive_intelligence_agent
uv run --script main.py
```

Or pass custom competitors with the Flyte CLI:

```
flyte run main.py competitive_intelligence \
  --competitors '["Anthropic", "OpenAI"]'
```

To test locally without Flyte secrets, export the environment variables directly:

```
export YOU_API_KEY=<YOUR_YOU_API_KEY>
export ANTHROPIC_API_KEY=<YOUR_ANTHROPIC_API_KEY>

uv run --script main.py
```
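Because the tasks read the keys from `YOU_API_KEY` and `ANTHROPIC_API_KEY` (the `as_env_var` names declared in the `TaskEnvironment`), a quick pre-flight check can catch a missing export before you start a run. This is a hypothetical helper, not part of the example repository:

```
import os
from collections.abc import Mapping

# Secret key -> environment variable, as declared in the TaskEnvironment.
REQUIRED_ENV = {
    "youdotcom-api-key": "YOU_API_KEY",
    "internal-anthropic-api-key": "ANTHROPIC_API_KEY",
}


def missing_env(env: Mapping[str, str] = os.environ) -> list[str]:
    """Return the required environment variables that are unset or empty."""
    return [var for var in REQUIRED_ENV.values() if not env.get(var)]


if __name__ == "__main__":
    missing = missing_env()
    print("Missing:", ", ".join(missing) if missing else "none")
```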

When the run completes, open the Flyte report in the UI to review deltas grouped by competitor, each with a clickable You.com source citation.

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/agents/deep-research ===

# Deep research

> [!NOTE]
> Code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/deep_research_agent); based on work by [Together AI](https://github.com/togethercomputer/open_deep_research).

This example demonstrates how to build an agentic workflow for deep research—a multi-step reasoning system that mirrors how a human researcher explores, analyzes, and synthesizes information from the web.

Deep research refers to the iterative process of thoroughly investigating a topic: identifying relevant sources, evaluating their usefulness, refining the research direction, and ultimately producing a well-structured summary or report. It's a long-running task that requires the agent to reason over time, adapt its strategy, and chain multiple steps together, making it an ideal fit for an agentic architecture.

In this example, we use:

- [Tavily](https://www.tavily.com/) to search for and retrieve high-quality online resources.
- [LiteLLM](https://litellm.ai/) to route LLM calls that perform reasoning, evaluation, and synthesis.

The agent executes a multi-step trajectory:

- Parallel search across multiple queries.
- Evaluation of retrieved results.
- Adaptive iteration: If results are insufficient, it formulates new research queries and repeats the search-evaluate cycle.
- Synthesis: After a fixed number of iterations, it produces a comprehensive research report.
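The control flow of that trajectory can be sketched in plain Python. In this minimal sketch, `search` and `evaluate` are hypothetical synchronous callables standing in for the search-and-summarize and completeness-evaluation steps, and `max_iterations` is an assumed cap. The loop stops early when the evaluator returns no new queries (as `evaluate_research_completeness` does when research is complete) and otherwise stops at the cap.

```
from collections.abc import Callable


def research_loop(
    initial_queries: list[str],
    search: Callable[[list[str]], list[str]],
    evaluate: Callable[[list[str], list[str]], list[str]],
    max_iterations: int = 3,
) -> list[str]:
    """Search, evaluate, and refine until no new queries remain or the cap is hit."""
    queries = list(initial_queries)
    used: list[str] = []     # every query issued so far
    results: list[str] = []  # accumulated search results
    for _ in range(max_iterations):
        if not queries:
            break  # the evaluator judged the research complete
        results.extend(search(queries))  # the real agent runs these in parallel
        used.extend(queries)
        queries = evaluate(results, used)  # an empty list means "done"
    return results  # a synthesis step would turn these into the final report
```

Keeping the iteration cap separate from the evaluator's verdict bounds cost even if the model keeps proposing new queries.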

What makes this workflow compelling is its dynamic, evolving nature. The agent isn't just following a fixed plan; it's making decisions in context, using multiple prompts and reasoning steps to steer the process.

Flyte is well suited to this kind of system. It provides:

- Structured composition of dynamic reasoning steps
- Built-in parallelism for faster search and evaluation
- Traceability and observability into each step and iteration
- Scalability for long-running or compute-intensive workloads

![Result](https://raw.githubusercontent.com/unionai/unionai-docs-static/main/gifs/tutorials/deep-research/result.gif)

Throughout this guide, we'll show how to design this workflow with the Flyte SDK, using Python tools you already know.

## Setting up the environment

Let's begin by setting up the task environment. We define the following components:

- Secrets for the Together and Tavily API keys
- A custom image with the required Python packages and apt dependencies (`pandoc`, `texlive-xetex`)
- An external YAML file (`prompts.yaml`) containing all LLM prompts, copied into the image at `/root`

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "pydantic==2.11.5",
#    "litellm==1.72.2",
#    "tavily-python==0.7.5",
#    "together==1.5.24",
#    "markdown==3.8.2",
#    "pymdown-extensions==10.16.1",
# ]
# main = "main"
# params = ""
# ///

# {{docs-fragment env}}
import asyncio
import json
from pathlib import Path

import flyte
import yaml
from flyte.io import File
from libs.utils.data_types import (
    DeepResearchResult,
    DeepResearchResults,
    ResearchPlan,
    SourceList,
)
from libs.utils.generation import generate_html, generate_toc_image
from libs.utils.llms import asingle_shot_llm_call
from libs.utils.log import AgentLogger
from libs.utils.tavily_search import atavily_search_results

TIME_LIMIT_MULTIPLIER = 5
MAX_COMPLETION_TOKENS = 4096

logging = AgentLogger("together.open_deep_research")

env = flyte.TaskEnvironment(
    name="deep-researcher",
    secrets=[
        flyte.Secret(key="together_api_key", as_env_var="TOGETHER_API_KEY"),
        flyte.Secret(key="tavily_api_key", as_env_var="TAVILY_API_KEY"),
    ],
    image=flyte.Image.from_uv_script(__file__, name="deep-research-agent", pre=True)
    .with_apt_packages("pandoc", "texlive-xetex")
    .with_source_file(Path("prompts.yaml"), "/root"),
    resources=flyte.Resources(cpu=1),
)
# {{/docs-fragment env}}

# {{docs-fragment generate_research_queries}}
@env.task
async def generate_research_queries(
    topic: str,
    planning_model: str,
    json_model: str,
    prompts_file: File,
) -> list[str]:
    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)
    PLANNING_PROMPT = prompts["planning_prompt"]

    plan = ""
    logging.info(f"\n\nGenerated deep research plan for topic: {topic}\n\nPlan:")
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=PLANNING_PROMPT,
        message=f"Research Topic: {topic}",
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        plan += chunk
        print(chunk, end="", flush=True)

    SEARCH_PROMPT = prompts["plan_parsing_prompt"]

    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=SEARCH_PROMPT,
        message=f"Plan to be parsed: {plan}",
        response_format={
            "type": "json_object",
            "schema": ResearchPlan.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk

    plan = json.loads(response_json)
    return plan["queries"]
# {{/docs-fragment generate_research_queries}}

async def _summarize_content_async(
    raw_content: str,
    query: str,
    prompt: str,
    summarization_model: str,
) -> str:
    """Summarize content asynchronously using the LLM"""
    logging.info("Summarizing content asynchronously using the LLM")

    result = ""
    async for chunk in asingle_shot_llm_call(
        model=summarization_model,
        system_prompt=prompt,
        message=f"<Raw Content>{raw_content}</Raw Content>\n\n<Research Topic>{query}</Research Topic>",
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        result += chunk
    return result

# {{docs-fragment search_and_summarize}}
@env.task
async def search_and_summarize(
    query: str,
    prompts_file: File,
    summarization_model: str,
) -> DeepResearchResults:
    """Perform search for a single query"""

    if len(query) > 400:
        # NOTE: we are truncating the query to 400 characters to avoid Tavily Search issues
        query = query[:400]
        logging.info(f"Truncated query to 400 characters: {query}")

    response = await atavily_search_results(query)

    logging.info("Tavily Search Called.")

    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)
    RAW_CONTENT_SUMMARIZER_PROMPT = prompts["raw_content_summarizer_prompt"]

    with flyte.group("summarize-content"):
        # Create tasks for summarization
        summarization_tasks = []
        result_info = []
        for result in response.results:
            if result.raw_content is None:
                continue

            task = _summarize_content_async(
                result.raw_content,
                query,
                RAW_CONTENT_SUMMARIZER_PROMPT,
                summarization_model,
            )
            summarization_tasks.append(task)
            result_info.append(result)

        # Use return_exceptions=True to prevent exceptions from propagating
        summarized_contents = await asyncio.gather(
            *summarization_tasks, return_exceptions=True
        )

    # Drop failed summarizations while keeping each summary paired with its
    # source result; filtering the summaries on their own would shift later
    # summaries onto the wrong results
    formatted_results = []
    for result, summarized_content in zip(result_info, summarized_contents):
        if isinstance(summarized_content, BaseException):
            continue
        formatted_results.append(
            DeepResearchResult(
                title=result.title,
                link=result.link,
                content=result.content,
                raw_content=result.raw_content,
                filtered_raw_content=summarized_content,
            )
        )
    return DeepResearchResults(results=formatted_results)
# {{/docs-fragment search_and_summarize}}

@env.task
async def search_all_queries(
    queries: list[str], summarization_model: str, prompts_file: File
) -> DeepResearchResults:
    """Execute searches for all queries in parallel"""
    tasks = [
        search_and_summarize(query, prompts_file, summarization_model)
        for query in queries
    ]

    # Run all searches concurrently; an empty query list yields no results
    results_list = await asyncio.gather(*tasks) if tasks else []

    # Combine all results
    combined_results = DeepResearchResults(results=[])
    for results in results_list:
        combined_results = combined_results + results

    return combined_results

# {{docs-fragment evaluate_research_completeness}}
@env.task
async def evaluate_research_completeness(
    topic: str,
    results: DeepResearchResults,
    queries: list[str],
    prompts_file: File,
    planning_model: str,
    json_model: str,
) -> list[str]:
    """
    Evaluate if the current search results are sufficient or if more research is needed.
    Returns an empty list if research is complete, or a list of additional queries if more research is needed.
    """

    # Format the search results for the LLM
    formatted_results = str(results)

    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)

    EVALUATION_PROMPT = prompts["evaluation_prompt"]

    logging.info("\nEvaluation: ")
    evaluation = ""
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=EVALUATION_PROMPT,
        message=(
            f"<Research Topic>{topic}</Research Topic>\n\n"
            f"<Search Queries Used>{queries}</Search Queries Used>\n\n"
            f"<Current Search Results>{formatted_results}</Current Search Results>"
        ),
        response_format=None,
        max_completion_tokens=None,
    ):
        evaluation += chunk
        print(chunk, end="", flush=True)

    EVALUATION_PARSING_PROMPT = prompts["evaluation_parsing_prompt"]

    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=EVALUATION_PARSING_PROMPT,
        message=f"Evaluation to be parsed: {evaluation}",
        response_format={
            "type": "json_object",
            "schema": ResearchPlan.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk

    evaluation = json.loads(response_json)
    return evaluation["queries"]
# {{/docs-fragment evaluate_research_completeness}}

# {{docs-fragment filter_results}}
@env.task
async def filter_results(
    topic: str,
    results: DeepResearchResults,
    prompts_file: File,
    planning_model: str,
    json_model: str,
    max_sources: int,
) -> DeepResearchResults:
    """Filter the search results based on the research plan"""

    # Format the search results for the LLM, without the raw content
    formatted_results = str(results)

    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)
    FILTER_PROMPT = prompts["filter_prompt"]

    logging.info("\nFilter response: ")
    filter_response = ""
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=FILTER_PROMPT,
        message=(
            f"<Research Topic>{topic}</Research Topic>\n\n"
            f"<Current Search Results>{formatted_results}</Current Search Results>"
        ),
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        filter_response += chunk
        print(chunk, end="", flush=True)

    logging.info(f"Filter response: {filter_response}")

    FILTER_PARSING_PROMPT = prompts["filter_parsing_prompt"]

    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=FILTER_PARSING_PROMPT,
        message=f"Filter response to be parsed: {filter_response}",
        response_format={
            "type": "json_object",
            "schema": SourceList.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk

    sources = json.loads(response_json)["sources"]

    logging.info(f"Filtered sources: {sources}")

    if max_sources != -1:
        sources = sources[:max_sources]

    # Source numbers are 1-based; ignore any that fall outside the result list
    filtered_results = [
        results.results[i - 1] for i in sources if 1 <= i <= len(results.results)
    ]

    return DeepResearchResults(results=filtered_results)
# {{/docs-fragment filter_results}}

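def _remove_thinking_tags(answer: str) -> str:
    """Remove content within <think>...</think> tags"""
    while True:
        start = answer.find("<think>")
        if start == -1:
            break
        # Look for the closing tag only after the opening tag, so a stray
        # "</think>" earlier in the text cannot cause an endless loop
        end = answer.find("</think>", start)
        if end == -1:
            break
        answer = answer[:start] + answer[end + len("</think>") :]
    return answer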

# {{docs-fragment generate_research_answer}}
@env.task
async def generate_research_answer(
    topic: str,
    results: DeepResearchResults,
    remove_thinking_tags: bool,
    prompts_file: File,
    answer_model: str,
) -> str:
    """
    Generate a comprehensive answer to the research topic based on the search results.
    Returns a detailed response that synthesizes information from all search results.
    """

    formatted_results = str(results)
    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)
    ANSWER_PROMPT = prompts["answer_prompt"]

    answer = ""
    async for chunk in asingle_shot_llm_call(
        model=answer_model,
        system_prompt=ANSWER_PROMPT,
        message=f"Research Topic: {topic}\n\nSearch Results:\n{formatted_results}",
        response_format=None,
        # NOTE: This is the max_tokens parameter for the LLM call on Together AI;
        # it may need to be changed for other providers
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        answer += chunk

    # this is just to avoid typing complaints
    if answer is None or not isinstance(answer, str):
        logging.error("No answer generated")
        return "No answer generated"

    if remove_thinking_tags:
        # Remove content within <think> tags
        answer = _remove_thinking_tags(answer)

    # Remove markdown code block markers if they exist at the beginning
    if answer.lstrip().startswith("```"):
        # Find the first line break after the opening backticks
        first_linebreak = answer.find("\n", answer.find("```"))
        if first_linebreak != -1:
            # Remove everything up to and including the first line break
            answer = answer[first_linebreak + 1 :]

        # Remove closing code block if it exists
        if answer.rstrip().endswith("```"):
            answer = answer.rstrip()[:-3].rstrip()

    return answer.strip()
# {{/docs-fragment generate_research_answer}}

# {{docs-fragment research_topic}}
@env.task(retries=flyte.RetryStrategy(count=3, backoff=10, backoff_factor=2))
async def research_topic(
    topic: str,
    budget: int = 3,
    remove_thinking_tags: bool = True,
    max_queries: int = 5,
    answer_model: str = "together_ai/deepseek-ai/DeepSeek-V3",
    planning_model: str = "together_ai/Qwen/Qwen2.5-72B-Instruct-Turbo",
    json_model: str = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
    max_sources: int = 40,
    summarization_model: str = "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
    prompts_file: File | str = "prompts.yaml",
) -> str:
    """Main method to conduct research on a topic. Will be used for weave evals."""
    if isinstance(prompts_file, str):
        prompts_file = await File.from_local(prompts_file)

    # Step 1: Generate initial queries
    queries = await generate_research_queries(
        topic=topic,
        planning_model=planning_model,
        json_model=json_model,
        prompts_file=prompts_file,
    )
    queries = [topic, *queries[: max_queries - 1]]
    all_queries = queries.copy()
    logging.info(f"Initial queries: {queries}")

    if len(queries) == 0:
        logging.error("No initial queries generated")
        return "No initial queries generated"

    # Step 2: Perform initial search
    results = await search_all_queries(queries, summarization_model, prompts_file)
    logging.info(f"Initial search complete, found {len(results.results)} results")

    # Step 3: Conduct iterative research within budget
    for iteration in range(budget):
        with flyte.group(f"eval_iteration_{iteration}"):
            # Evaluate if more research is needed
            additional_queries = await evaluate_research_completeness(
                topic=topic,
                results=results,
                queries=all_queries,
                prompts_file=prompts_file,
                planning_model=planning_model,
                json_model=json_model,
            )

            # Filter out empty strings and check if any queries remain
            additional_queries = [q for q in additional_queries if q]
            if not additional_queries:
                logging.info("No need for additional research")
                break

            # for debugging purposes we limit the number of queries
            additional_queries = additional_queries[:max_queries]
            logging.info(f"Additional queries: {additional_queries}")

            # Expand research with new queries
            new_results = await search_all_queries(
                additional_queries, summarization_model, prompts_file
            )
            logging.info(
                f"Follow-up search complete, found {len(new_results.results)} results"
            )

            results = results + new_results
            all_queries.extend(additional_queries)

    # Step 4: Generate final answer
    logging.info(f"Generating final answer for topic: {topic}")
    results = results.dedup()
    logging.info(f"Deduplication complete, kept {len(results.results)} results")
    filtered_results = await filter_results(
        topic=topic,
        results=results,
        prompts_file=prompts_file,
        planning_model=planning_model,
        json_model=json_model,
        max_sources=max_sources,
    )
    logging.info(
        f"LLM Filtering complete, kept {len(filtered_results.results)} results"
    )

    # Generate final answer
    answer = await generate_research_answer(
        topic=topic,
        results=filtered_results,
        remove_thinking_tags=remove_thinking_tags,
        prompts_file=prompts_file,
        answer_model=answer_model,
    )

    return answer
# {{/docs-fragment research_topic}}

# {{docs-fragment main}}
@env.task(report=True)
async def main(
    topic: str = (
        "List the essential requirements for a developer-focused agent orchestration system."
    ),
    prompts_file: File | str = "/root/prompts.yaml",
    budget: int = 2,
    remove_thinking_tags: bool = True,
    max_queries: int = 3,
    answer_model: str = "together_ai/deepseek-ai/DeepSeek-V3",
    planning_model: str = "together_ai/Qwen/Qwen2.5-72B-Instruct-Turbo",
    json_model: str = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
    max_sources: int = 10,
    summarization_model: str = "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
) -> str:
    if isinstance(prompts_file, str):
        prompts_file = await File.from_local(prompts_file)

    answer = await research_topic(
        topic=topic,
        budget=budget,
        remove_thinking_tags=remove_thinking_tags,
        max_queries=max_queries,
        answer_model=answer_model,
        planning_model=planning_model,
        json_model=json_model,
        max_sources=max_sources,
        summarization_model=summarization_model,
        prompts_file=prompts_file,
    )

    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    toc_image_url = await generate_toc_image(
        yaml.safe_load(yaml_contents)["data_visualization_prompt"],
        planning_model,
        topic,
    )

    html_content = await generate_html(answer, toc_image_url)
    await flyte.report.replace.aio(html_content, do_flush=True)
    await flyte.report.flush.aio()

    return html_content
# {{/docs-fragment main}}

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/deep_research_agent/agent.py*

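The Python dependencies are declared at the top of the file as inline script metadata (the PEP 723 format that `uv` reads); `flyte.Image.from_uv_script` builds the task image from them: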

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "pydantic==2.11.5",
#    "litellm==1.72.2",
#    "tavily-python==0.7.5",
#    "together==1.5.24",
#    "markdown==3.8.2",
#    "pymdown-extensions==10.16.1",
# ]
# ///
```

## Generate research queries

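The `generate_research_queries` task converts the research topic into a list of focused search queries. It makes two LLM calls: the first, to `planning_model`, drafts a high-level research plan; the second, to `json_model`, parses that plan into individual queries that conform to the `ResearchPlan` JSON schema.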

```
@env.task
async def generate_research_queries(
    topic: str,
    planning_model: str,
    json_model: str,
    prompts_file: File,
) -> list[str]:
    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)
    PLANNING_PROMPT = prompts["planning_prompt"]

    plan = ""
    logging.info(f"\n\nGenerated deep research plan for topic: {topic}\n\nPlan:")
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=PLANNING_PROMPT,
        message=f"Research Topic: {topic}",
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        plan += chunk
        print(chunk, end="", flush=True)

    SEARCH_PROMPT = prompts["plan_parsing_prompt"]

    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=SEARCH_PROMPT,
        message=f"Plan to be parsed: {plan}",
        response_format={
            "type": "json_object",
            "schema": ResearchPlan.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk

    plan = json.loads(response_json)
    return plan["queries"]
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/deep_research_agent/agent.py*
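
The two-call pattern in `generate_research_queries` can be exercised without API keys by swapping in a stub. In the sketch below, `fake_llm_call` is a hypothetical stand-in for `asingle_shot_llm_call` that streams canned replies: the first call returns a free-form plan, and the second returns JSON assumed to carry a `queries` list, as the task's `plan["queries"]` lookup implies. The last step mirrors how `research_topic` keeps the topic itself as the first query and trims the rest to `max_queries`.

```python
import asyncio
import json
from typing import AsyncIterator


async def fake_llm_call(system_prompt: str, message: str) -> AsyncIterator[str]:
    """Hypothetical stand-in for asingle_shot_llm_call: streams a canned reply."""
    if system_prompt == "plan":
        reply = "1. Survey orchestration frameworks\n2. Compare retry semantics"
    else:
        reply = json.dumps(
            {"queries": ["agent orchestration frameworks", "task retry semantics"]}
        )
    # Stream the reply in small chunks, as a streaming API would
    for i in range(0, len(reply), 8):
        yield reply[i : i + 8]


async def generate_queries(topic: str) -> list[str]:
    # Call 1: free-form research plan
    plan = ""
    async for chunk in fake_llm_call("plan", f"Research Topic: {topic}"):
        plan += chunk
    # Call 2: convert the plan into JSON with a "queries" list
    response_json = ""
    async for chunk in fake_llm_call("parse", f"Plan to be parsed: {plan}"):
        response_json += chunk
    return json.loads(response_json)["queries"]


topic = "agent orchestration"
max_queries = 3
queries = asyncio.run(generate_queries(topic))
# Keep the topic as the first query, then up to max_queries - 1 generated ones
queries = [topic, *queries[: max_queries - 1]]
print(queries)
```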

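LLM calls go through LiteLLM. The streaming helper used by the tasks, `asingle_shot_llm_call`, is decorated with `flyte.trace` for observability; the synchronous `single_shot_llm_call` is not: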

```
from typing import Any, AsyncIterator, Optional

from litellm import acompletion, completion

import flyte

# {{docs-fragment asingle_shot_llm_call}}
@flyte.trace
async def asingle_shot_llm_call(
    model: str,
    system_prompt: str,
    message: str,
    response_format: Optional[dict[str, str | dict[str, Any]]] = None,
    max_completion_tokens: int | None = None,
) -> AsyncIterator[str]:
    stream = await acompletion(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": message},
        ],
        temperature=0.0,
        response_format=response_format,
        # NOTE: max_tokens is deprecated per the OpenAI API docs; use max_completion_tokens instead if possible
        # NOTE: max_completion_tokens is not currently supported by Together AI, so we use max_tokens instead
        max_tokens=max_completion_tokens,
        timeout=600,
        stream=True,
    )
    async for chunk in stream:
        content = chunk.choices[0].delta.get("content", "")
        if content:
            yield content

# {{/docs-fragment asingle_shot_llm_call}}

def single_shot_llm_call(
    model: str,
    system_prompt: str,
    message: str,
    response_format: Optional[dict[str, str | dict[str, Any]]] = None,
    max_completion_tokens: int | None = None,
) -> str:
    response = completion(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": message},
        ],
        temperature=0.0,
        response_format=response_format,
        # NOTE: max_tokens is deprecated per the OpenAI API docs; use max_completion_tokens instead if possible
        # NOTE: max_completion_tokens is not currently supported by Together AI, so we use max_tokens instead
        max_tokens=max_completion_tokens,
        timeout=600,
    )
    return response.choices[0].message["content"]  # type: ignore
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/deep_research_agent/libs/utils/llms.py*

> [!NOTE]
> We use `flyte.trace` to track intermediate steps within a task, like LLM calls or specific function executions. This lightweight decorator adds observability with minimal overhead and is especially useful for inspecting reasoning chains during task execution.
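You can try the streaming pattern on its own without an LLM provider. In the sketch below, `fake_llm_stream` is a hypothetical stand-in for the chunks LiteLLM yields when `stream=True`. `collect` mirrors how the tasks in `agent.py` concatenate chunks before using the full text, for example before passing it to `json.loads`. The sketch omits `flyte.trace` and makes no network calls.

```python
import asyncio
from typing import AsyncIterator


async def fake_llm_stream(text: str, chunk_size: int = 4) -> AsyncIterator[str]:
    """Hypothetical stand-in for a streamed LLM response: yields text in small chunks."""
    for i in range(0, len(text), chunk_size):
        await asyncio.sleep(0)  # hand control back to the event loop, like a network stream
        yield text[i : i + chunk_size]


async def collect(stream: AsyncIterator[str]) -> str:
    """Accumulate streamed chunks into one string, as the tasks in agent.py do."""
    parts: list[str] = []
    async for chunk in stream:
        if chunk:  # skip empty deltas
            parts.append(chunk)
    return "".join(parts)


if __name__ == "__main__":
    print(asyncio.run(collect(fake_llm_stream('{"queries": ["a", "b"]}'))))
```

Accumulating every chunk before parsing is what makes `response_format` JSON safe to decode, because individual chunks are rarely valid JSON by themselves.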

## Search and summarize

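We submit each research query to Tavily and summarize the results with an LLM. Inside `search_and_summarize`, the per-result summarization calls are ordinary coroutines, and `asyncio.gather` runs them concurrently within the task. One level up, `search_all_queries` uses `asyncio.gather` to launch one `search_and_summarize` task per query. Because these are Flyte tasks, Flyte can run them in parallel on separate compute resources.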
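The concurrency and error handling in `search_and_summarize` reduce to the plain-`asyncio` pattern below. `summarize` is a hypothetical stand-in for `_summarize_content_async` that fails on empty input. With `return_exceptions=True`, `asyncio.gather` returns outcomes in input order and puts exceptions in place of failed results. For that reason, the sketch pairs each outcome with its input before dropping failures.

```python
import asyncio


async def summarize(raw: str) -> str:
    """Hypothetical summarizer: raises on empty input to simulate an LLM error."""
    await asyncio.sleep(0)
    if not raw:
        raise ValueError("nothing to summarize")
    return raw.upper()


async def summarize_all(raw_contents: list[str]) -> list[tuple[str, str]]:
    # Run all summaries concurrently; failed calls come back as exception objects
    summaries = await asyncio.gather(
        *(summarize(raw) for raw in raw_contents), return_exceptions=True
    )
    # Pair each input with its own outcome *before* dropping failures, so a
    # failure at index i cannot shift later summaries onto the wrong input
    return [
        (raw, summary)
        for raw, summary in zip(raw_contents, summaries)
        if not isinstance(summary, BaseException)
    ]


if __name__ == "__main__":
    print(asyncio.run(summarize_all(["a", "", "c"])))
```

If the exceptions were filtered out first and the lists zipped afterwards, a single failure would misalign every later summary with its search result.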

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "pydantic==2.11.5",
#    "litellm==1.72.2",
#    "tavily-python==0.7.5",
#    "together==1.5.24",
#    "markdown==3.8.2",
#    "pymdown-extensions==10.16.1",
# ]
# main = "main"
# params = ""
# ///

# {{docs-fragment env}}
import asyncio
import json
from pathlib import Path

import flyte
import yaml
from flyte.io import File
from libs.utils.data_types import (
    DeepResearchResult,
    DeepResearchResults,
    ResearchPlan,
    SourceList,
)
from libs.utils.generation import generate_html, generate_toc_image
from libs.utils.llms import asingle_shot_llm_call
from libs.utils.log import AgentLogger
from libs.utils.tavily_search import atavily_search_results

TIME_LIMIT_MULTIPLIER = 5
MAX_COMPLETION_TOKENS = 4096

logging = AgentLogger("together.open_deep_research")

env = flyte.TaskEnvironment(
    name="deep-researcher",
    secrets=[
        flyte.Secret(key="together_api_key", as_env_var="TOGETHER_API_KEY"),
        flyte.Secret(key="tavily_api_key", as_env_var="TAVILY_API_KEY"),
    ],
    image=flyte.Image.from_uv_script(__file__, name="deep-research-agent", pre=True)
    .with_apt_packages("pandoc", "texlive-xetex")
    .with_source_file(Path("prompts.yaml"), "/root"),
    resources=flyte.Resources(cpu=1),
)
# {{/docs-fragment env}}

# {{docs-fragment generate_research_queries}}
@env.task
async def generate_research_queries(
    topic: str,
    planning_model: str,
    json_model: str,
    prompts_file: File,
) -> list[str]:
    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)
    PLANNING_PROMPT = prompts["planning_prompt"]

    plan = ""
    logging.info(f"\n\nGenerated deep research plan for topic: {topic}\n\nPlan:")
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=PLANNING_PROMPT,
        message=f"Research Topic: {topic}",
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        plan += chunk
        print(chunk, end="", flush=True)

    SEARCH_PROMPT = prompts["plan_parsing_prompt"]

    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=SEARCH_PROMPT,
        message=f"Plan to be parsed: {plan}",
        response_format={
            "type": "json_object",
            "schema": ResearchPlan.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk

    plan = json.loads(response_json)
    return plan["queries"]
# {{/docs-fragment generate_research_queries}}

async def _summarize_content_async(
    raw_content: str,
    query: str,
    prompt: str,
    summarization_model: str,
) -> str:
    """Summarize content asynchronously using the LLM"""
    logging.info("Summarizing content asynchronously using the LLM")

    result = ""
    async for chunk in asingle_shot_llm_call(
        model=summarization_model,
        system_prompt=prompt,
        message=f"<Raw Content>{raw_content}</Raw Content>\n\n<Research Topic>{query}</Research Topic>",
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        result += chunk
    return result

# {{docs-fragment search_and_summarize}}
@env.task
async def search_and_summarize(
    query: str,
    prompts_file: File,
    summarization_model: str,
) -> DeepResearchResults:
    """Perform search for a single query"""

    if len(query) > 400:
        # NOTE: we are truncating the query to 400 characters to avoid Tavily Search issues
        query = query[:400]
        logging.info(f"Truncated query to 400 characters: {query}")

    response = await atavily_search_results(query)

    logging.info("Tavily Search Called.")

    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)
    RAW_CONTENT_SUMMARIZER_PROMPT = prompts["raw_content_summarizer_prompt"]

    with flyte.group("summarize-content"):
        # Create tasks for summarization
        summarization_tasks = []
        result_info = []
        for result in response.results:
            if result.raw_content is None:
                continue

            task = _summarize_content_async(
                result.raw_content,
                query,
                RAW_CONTENT_SUMMARIZER_PROMPT,
                summarization_model,
            )
            summarization_tasks.append(task)
            result_info.append(result)

        # Use return_exceptions=True to prevent exceptions from propagating
        summarized_contents = await asyncio.gather(
            *summarization_tasks, return_exceptions=True
        )

    formatted_results = []
    for result, summarized_content in zip(result_info, summarized_contents):
        # Skip results whose summarization failed. Filtering here, rather than
        # before the zip, keeps each summary paired with its own search result.
        if isinstance(summarized_content, Exception):
            continue
        formatted_results.append(
            DeepResearchResult(
                title=result.title,
                link=result.link,
                content=result.content,
                raw_content=result.raw_content,
                filtered_raw_content=summarized_content,
            )
        )
    return DeepResearchResults(results=formatted_results)
# {{/docs-fragment search_and_summarize}}

@env.task
async def search_all_queries(
    queries: list[str], summarization_model: str, prompts_file: File
) -> DeepResearchResults:
    """Execute searches for all queries in parallel"""
    tasks = [
        search_and_summarize(query, prompts_file, summarization_model)
        for query in queries
    ]

    # Run one search_and_summarize task per query in parallel; an empty
    # query list yields no results instead of an undefined variable
    results_list = await asyncio.gather(*tasks) if tasks else []

    # Combine all results
    combined_results = DeepResearchResults(results=[])
    for results in results_list:
        combined_results = combined_results + results

    return combined_results

# {{docs-fragment evaluate_research_completeness}}
@env.task
async def evaluate_research_completeness(
    topic: str,
    results: DeepResearchResults,
    queries: list[str],
    prompts_file: File,
    planning_model: str,
    json_model: str,
) -> list[str]:
    """
    Evaluate if the current search results are sufficient or if more research is needed.
    Returns an empty list if research is complete, or a list of additional queries if more research is needed.
    """

    # Format the search results for the LLM
    formatted_results = str(results)

    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)

    EVALUATION_PROMPT = prompts["evaluation_prompt"]

    logging.info("\nEvaluation: ")
    evaluation = ""
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=EVALUATION_PROMPT,
        message=(
            f"<Research Topic>{topic}</Research Topic>\n\n"
            f"<Search Queries Used>{queries}</Search Queries Used>\n\n"
            f"<Current Search Results>{formatted_results}</Current Search Results>"
        ),
        response_format=None,
        max_completion_tokens=None,
    ):
        evaluation += chunk
        print(chunk, end="", flush=True)

    EVALUATION_PARSING_PROMPT = prompts["evaluation_parsing_prompt"]

    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=EVALUATION_PARSING_PROMPT,
        message=f"Evaluation to be parsed: {evaluation}",
        response_format={
            "type": "json_object",
            "schema": ResearchPlan.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk

    evaluation = json.loads(response_json)
    return evaluation["queries"]
# {{/docs-fragment evaluate_research_completeness}}

# {{docs-fragment filter_results}}
@env.task
async def filter_results(
    topic: str,
    results: DeepResearchResults,
    prompts_file: File,
    planning_model: str,
    json_model: str,
    max_sources: int,
) -> DeepResearchResults:
    """Filter the search results based on the research plan"""

    # Format the search results for the LLM, without the raw content
    formatted_results = str(results)

    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)
    FILTER_PROMPT = prompts["filter_prompt"]

    logging.info("\nFilter response: ")
    filter_response = ""
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=FILTER_PROMPT,
        message=(
            f"<Research Topic>{topic}</Research Topic>\n\n"
            f"<Current Search Results>{formatted_results}</Current Search Results>"
        ),
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        filter_response += chunk
        print(chunk, end="", flush=True)

    logging.info(f"Filter response: {filter_response}")

    FILTER_PARSING_PROMPT = prompts["filter_parsing_prompt"]

    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=FILTER_PARSING_PROMPT,
        message=f"Filter response to be parsed: {filter_response}",
        response_format={
            "type": "json_object",
            "schema": SourceList.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk

    sources = json.loads(response_json)["sources"]

    logging.info(f"Filtered sources: {sources}")

    if max_sources != -1:
        sources = sources[:max_sources]

    # Filter the results based on the source list (source numbers are 1-based;
    # out-of-range numbers, including 0, are ignored)
    filtered_results = [
        results.results[i - 1] for i in sources if 1 <= i <= len(results.results)
    ]

    return DeepResearchResults(results=filtered_results)
# {{/docs-fragment filter_results}}

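def _remove_thinking_tags(answer: str) -> str:
    """Remove content within <think>...</think> tags"""
    while True:
        start = answer.find("<think>")
        if start == -1:
            break
        # Look for the closing tag only after the opening tag, so a stray
        # "</think>" earlier in the text cannot cause an endless loop
        end = answer.find("</think>", start)
        if end == -1:
            break
        answer = answer[:start] + answer[end + len("</think>") :]
    return answer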

# {{docs-fragment generate_research_answer}}
@env.task
async def generate_research_answer(
    topic: str,
    results: DeepResearchResults,
    remove_thinking_tags: bool,
    prompts_file: File,
    answer_model: str,
) -> str:
    """
    Generate a comprehensive answer to the research topic based on the search results.
    Returns a detailed response that synthesizes information from all search results.
    """

    formatted_results = str(results)
    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)
    ANSWER_PROMPT = prompts["answer_prompt"]

    answer = ""
    async for chunk in asingle_shot_llm_call(
        model=answer_model,
        system_prompt=ANSWER_PROMPT,
        message=f"Research Topic: {topic}\n\nSearch Results:\n{formatted_results}",
        response_format=None,
        # NOTE: This is passed as the max_tokens parameter for the LLM call on Together AI;
        # it may need to be changed for other providers
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        answer += chunk

    # Guard against an empty completion
    if not answer.strip():
        logging.error("No answer generated")
        return "No answer generated"

    if remove_thinking_tags:
        # Remove content within <think> tags
        answer = _remove_thinking_tags(answer)

    # Remove markdown code block markers if they exist at the beginning
    if answer.lstrip().startswith("```"):
        # Find the first line break after the opening backticks
        first_linebreak = answer.find("\n", answer.find("```"))
        if first_linebreak != -1:
            # Remove everything up to and including the first line break
            answer = answer[first_linebreak + 1 :]

        # Remove closing code block if it exists
        if answer.rstrip().endswith("```"):
            answer = answer.rstrip()[:-3].rstrip()

    return answer.strip()
# {{/docs-fragment generate_research_answer}}

# {{docs-fragment research_topic}}
@env.task(retries=flyte.RetryStrategy(count=3, backoff=10, backoff_factor=2))
async def research_topic(
    topic: str,
    budget: int = 3,
    remove_thinking_tags: bool = True,
    max_queries: int = 5,
    answer_model: str = "together_ai/deepseek-ai/DeepSeek-V3",
    planning_model: str = "together_ai/Qwen/Qwen2.5-72B-Instruct-Turbo",
    json_model: str = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
    max_sources: int = 40,
    summarization_model: str = "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
    prompts_file: File | str = "prompts.yaml",
) -> str:
    """Main method to conduct research on a topic. Will be used for weave evals."""
    if isinstance(prompts_file, str):
        prompts_file = await File.from_local(prompts_file)

    # Step 1: Generate initial queries
    queries = await generate_research_queries(
        topic=topic,
        planning_model=planning_model,
        json_model=json_model,
        prompts_file=prompts_file,
    )
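    # Check before prepending the topic; afterwards the list is never empty
    if len(queries) == 0:
        logging.error("No initial queries generated")
        return "No initial queries generated"

    # Always search the topic itself, plus up to max_queries - 1 generated queries
    queries = [topic, *queries[: max_queries - 1]]
    all_queries = queries.copy()
    logging.info(f"Initial queries: {queries}")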

    # Step 2: Perform initial search
    results = await search_all_queries(queries, summarization_model, prompts_file)
    logging.info(f"Initial search complete, found {len(results.results)} results")

    # Step 3: Conduct iterative research within budget
    for iteration in range(budget):
        with flyte.group(f"eval_iteration_{iteration}"):
            # Evaluate if more research is needed
            additional_queries = await evaluate_research_completeness(
                topic=topic,
                results=results,
                queries=all_queries,
                prompts_file=prompts_file,
                planning_model=planning_model,
                json_model=json_model,
            )

            # Filter out empty strings and check if any queries remain
            additional_queries = [q for q in additional_queries if q]
            if not additional_queries:
                logging.info("No need for additional research")
                break

            # for debugging purposes we limit the number of queries
            additional_queries = additional_queries[:max_queries]
            logging.info(f"Additional queries: {additional_queries}")

            # Expand research with new queries
            new_results = await search_all_queries(
                additional_queries, summarization_model, prompts_file
            )
            logging.info(
                f"Follow-up search complete, found {len(new_results.results)} results"
            )

            results = results + new_results
            all_queries.extend(additional_queries)

    # Step 4: Generate final answer
    logging.info(f"Generating final answer for topic: {topic}")
    results = results.dedup()
    logging.info(f"Deduplication complete, kept {len(results.results)} results")
    filtered_results = await filter_results(
        topic=topic,
        results=results,
        prompts_file=prompts_file,
        planning_model=planning_model,
        json_model=json_model,
        max_sources=max_sources,
    )
    logging.info(
        f"LLM Filtering complete, kept {len(filtered_results.results)} results"
    )

    # Generate final answer
    answer = await generate_research_answer(
        topic=topic,
        results=filtered_results,
        remove_thinking_tags=remove_thinking_tags,
        prompts_file=prompts_file,
        answer_model=answer_model,
    )

    return answer
# {{/docs-fragment research_topic}}

# {{docs-fragment main}}
@env.task(report=True)
async def main(
    topic: str = (
        "List the essential requirements for a developer-focused agent orchestration system."
    ),
    prompts_file: File | str = "/root/prompts.yaml",
    budget: int = 2,
    remove_thinking_tags: bool = True,
    max_queries: int = 3,
    answer_model: str = "together_ai/deepseek-ai/DeepSeek-V3",
    planning_model: str = "together_ai/Qwen/Qwen2.5-72B-Instruct-Turbo",
    json_model: str = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
    max_sources: int = 10,
    summarization_model: str = "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
) -> str:
    if isinstance(prompts_file, str):
        prompts_file = await File.from_local(prompts_file)

    answer = await research_topic(
        topic=topic,
        budget=budget,
        remove_thinking_tags=remove_thinking_tags,
        max_queries=max_queries,
        answer_model=answer_model,
        planning_model=planning_model,
        json_model=json_model,
        max_sources=max_sources,
        summarization_model=summarization_model,
        prompts_file=prompts_file,
    )

    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    toc_image_url = await generate_toc_image(
        yaml.safe_load(yaml_contents)["data_visualization_prompt"],
        planning_model,
        topic,
    )

    html_content = await generate_html(answer, toc_image_url)
    await flyte.report.replace.aio(html_content, do_flush=True)
    await flyte.report.flush.aio()

    return html_content
# {{/docs-fragment main}}

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/deep_research_agent/agent.py*
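`search_all_queries` folds the per-query results together with `+`, and `research_topic` later calls `.dedup()` on them. The real `DeepResearchResults` is defined in `libs/utils/data_types.py`, which is not shown here. The sketch below assumes it behaves like a list wrapper: `+` concatenates, and `dedup` keeps the first result for each link. Both `Result` and `Results` are hypothetical stand-ins.

```python
from dataclasses import dataclass, field


@dataclass(frozen=True)
class Result:
    """Hypothetical stand-in for DeepResearchResult (only the fields used here)."""
    title: str
    link: str


@dataclass
class Results:
    """Hypothetical stand-in for DeepResearchResults."""
    results: list[Result] = field(default_factory=list)

    def __add__(self, other: "Results") -> "Results":
        # Concatenate without mutating either operand
        return Results(results=self.results + other.results)

    def dedup(self) -> "Results":
        # Keep the first occurrence of each link, preserving order
        seen: set[str] = set()
        unique: list[Result] = []
        for r in self.results:
            if r.link not in seen:
                seen.add(r.link)
                unique.append(r)
        return Results(results=unique)


def combine(per_query: list[Results]) -> Results:
    # Same fold as search_all_queries: start empty, add each query's results
    combined = Results()
    for res in per_query:
        combined = combined + res
    return combined
```

Deduplicating once, just before filtering, lets overlapping queries return the same page without that page being summarized into the final prompt twice.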

## Evaluate research completeness

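Next, we assess whether the gathered research is sufficient. As with query generation, the task makes two LLM calls. The planning model evaluates the completeness of the results, and the JSON model parses that evaluation into a list of additional queries. An empty list means no further research is needed.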
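The way `research_topic` drives this evaluation can be sketched as a plain loop. This is a synchronous simplification: `search` and `evaluate` are hypothetical callables standing in for `search_all_queries` and `evaluate_research_completeness`, results are plain strings rather than `DeepResearchResults`, and `flyte.group` is omitted. The loop stops when the evaluator returns no queries or when the `budget` of iterations is used up.

```python
from typing import Callable


def iterative_research(
    initial_queries: list[str],
    search: Callable[[list[str]], list[str]],
    evaluate: Callable[[list[str], list[str]], list[str]],
    budget: int,
    max_queries: int,
) -> tuple[list[str], list[str]]:
    """Sketch of research_topic's loop: search, evaluate, repeat until done or out of budget."""
    all_queries = list(initial_queries)
    results = search(initial_queries)
    for _ in range(budget):
        # Drop empty strings; an empty list means research is complete
        additional = [q for q in evaluate(results, all_queries) if q]
        if not additional:
            break
        additional = additional[:max_queries]  # cap follow-up queries per round
        results = results + search(additional)
        all_queries.extend(additional)
    return results, all_queries
```

Passing `all_queries` to the evaluator lets it avoid proposing queries that were already searched, and `budget` puts a hard upper bound on cost even if the evaluator never reports completion.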

```
# {{docs-fragment evaluate_research_completeness}}
@env.task
async def evaluate_research_completeness(
    topic: str,
    results: DeepResearchResults,
    queries: list[str],
    prompts_file: File,
    planning_model: str,
    json_model: str,
) -> list[str]:
    """
    Evaluate if the current search results are sufficient or if more research is needed.
    Returns an empty list if research is complete, or a list of additional queries if more research is needed.
    """

    # Format the search results for the LLM
    formatted_results = str(results)

    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)

    EVALUATION_PROMPT = prompts["evaluation_prompt"]

    logging.info("\nEvaluation: ")
    evaluation = ""
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=EVALUATION_PROMPT,
        message=(
            f"<Research Topic>{topic}</Research Topic>\n\n"
            f"<Search Queries Used>{queries}</Search Queries Used>\n\n"
            f"<Current Search Results>{formatted_results}</Current Search Results>"
        ),
        response_format=None,
        max_completion_tokens=None,
    ):
        evaluation += chunk
        print(chunk, end="", flush=True)

    EVALUATION_PARSING_PROMPT = prompts["evaluation_parsing_prompt"]

    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=EVALUATION_PARSING_PROMPT,
        message=f"Evaluation to be parsed: {evaluation}",
        response_format={
            "type": "json_object",
            "schema": ResearchPlan.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk

    evaluation = json.loads(response_json)
    return evaluation["queries"]
# {{/docs-fragment evaluate_research_completeness}}

# {{docs-fragment filter_results}}
@env.task
async def filter_results(
    topic: str,
    results: DeepResearchResults,
    prompts_file: File,
    planning_model: str,
    json_model: str,
    max_sources: int,
) -> DeepResearchResults:
    """Filter the search results based on the research plan"""

    # Format the search results for the LLM, without the raw content
    formatted_results = str(results)

    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)
    FILTER_PROMPT = prompts["filter_prompt"]

    logging.info("\nFilter response: ")
    filter_response = ""
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=FILTER_PROMPT,
        message=(
            f"<Research Topic>{topic}</Research Topic>\n\n"
            f"<Current Search Results>{formatted_results}</Current Search Results>"
        ),
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        filter_response += chunk
        print(chunk, end="", flush=True)

    logging.info(f"Filter response: {filter_response}")

    FILTER_PARSING_PROMPT = prompts["filter_parsing_prompt"]

    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=FILTER_PARSING_PROMPT,
        message=f"Filter response to be parsed: {filter_response}",
        response_format={
            "type": "json_object",
            "schema": SourceList.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk

    sources = json.loads(response_json)["sources"]

    logging.info(f"Filtered sources: {sources}")

    if max_sources != -1:
        sources = sources[:max_sources]

    # Filter the results based on the source list (source numbers are 1-based;
    # out-of-range numbers, including 0, are ignored)
    filtered_results = [
        results.results[i - 1] for i in sources if 1 <= i <= len(results.results)
    ]

    return DeepResearchResults(results=filtered_results)
# {{/docs-fragment filter_results}}

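def _remove_thinking_tags(answer: str) -> str:
    """Remove content within <think>...</think> tags"""
    while True:
        start = answer.find("<think>")
        if start == -1:
            break
        # Look for the closing tag only after the opening tag, so a stray
        # "</think>" earlier in the text cannot cause an endless loop
        end = answer.find("</think>", start)
        if end == -1:
            break
        answer = answer[:start] + answer[end + len("</think>") :]
    return answer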

# {{docs-fragment generate_research_answer}}
@env.task
async def generate_research_answer(
    topic: str,
    results: DeepResearchResults,
    remove_thinking_tags: bool,
    prompts_file: File,
    answer_model: str,
) -> str:
    """
    Generate a comprehensive answer to the research topic based on the search results.
    Returns a detailed response that synthesizes information from all search results.
    """

    formatted_results = str(results)
    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)
    ANSWER_PROMPT = prompts["answer_prompt"]

    answer = ""
    async for chunk in asingle_shot_llm_call(
        model=answer_model,
        system_prompt=ANSWER_PROMPT,
        message=f"Research Topic: {topic}\n\nSearch Results:\n{formatted_results}",
        response_format=None,
        # NOTE: This is passed as the max_tokens parameter for the LLM call on Together AI;
        # it may need to be changed for other providers
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        answer += chunk

    # Guard against an empty completion
    if not answer.strip():
        logging.error("No answer generated")
        return "No answer generated"

    if remove_thinking_tags:
        # Remove content within <think> tags
        answer = _remove_thinking_tags(answer)

    # Remove markdown code block markers if they exist at the beginning
    if answer.lstrip().startswith("```"):
        # Find the first line break after the opening backticks
        first_linebreak = answer.find("\n", answer.find("```"))
        if first_linebreak != -1:
            # Remove everything up to and including the first line break
            answer = answer[first_linebreak + 1 :]

        # Remove closing code block if it exists
        if answer.rstrip().endswith("```"):
            answer = answer.rstrip()[:-3].rstrip()

    return answer.strip()
# {{/docs-fragment generate_research_answer}}

# {{docs-fragment research_topic}}
@env.task(retries=flyte.RetryStrategy(count=3, backoff=10, backoff_factor=2))
async def research_topic(
    topic: str,
    budget: int = 3,
    remove_thinking_tags: bool = True,
    max_queries: int = 5,
    answer_model: str = "together_ai/deepseek-ai/DeepSeek-V3",
    planning_model: str = "together_ai/Qwen/Qwen2.5-72B-Instruct-Turbo",
    json_model: str = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
    max_sources: int = 40,
    summarization_model: str = "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
    prompts_file: File | str = "prompts.yaml",
) -> str:
    """Main method to conduct research on a topic. Will be used for weave evals."""
    if isinstance(prompts_file, str):
        prompts_file = await File.from_local(prompts_file)

    # Step 1: Generate initial queries
    queries = await generate_research_queries(
        topic=topic,
        planning_model=planning_model,
        json_model=json_model,
        prompts_file=prompts_file,
    )
    queries = [topic, *queries[: max_queries - 1]]
    all_queries = queries.copy()
    logging.info(f"Initial queries: {queries}")

    if len(queries) == 0:
        logging.error("No initial queries generated")
        return "No initial queries generated"

    # Step 2: Perform initial search
    results = await search_all_queries(queries, summarization_model, prompts_file)
    logging.info(f"Initial search complete, found {len(results.results)} results")

    # Step 3: Conduct iterative research within budget
    for iteration in range(budget):
        with flyte.group(f"eval_iteration_{iteration}"):
            # Evaluate if more research is needed
            additional_queries = await evaluate_research_completeness(
                topic=topic,
                results=results,
                queries=all_queries,
                prompts_file=prompts_file,
                planning_model=planning_model,
                json_model=json_model,
            )

            # Filter out empty strings and check if any queries remain
            additional_queries = [q for q in additional_queries if q]
            if not additional_queries:
                logging.info("No need for additional research")
                break

            # for debugging purposes we limit the number of queries
            additional_queries = additional_queries[:max_queries]
            logging.info(f"Additional queries: {additional_queries}")

            # Expand research with new queries
            new_results = await search_all_queries(
                additional_queries, summarization_model, prompts_file
            )
            logging.info(
                f"Follow-up search complete, found {len(new_results.results)} results"
            )

            results = results + new_results
            all_queries.extend(additional_queries)

    # Step 4: Generate final answer
    logging.info(f"Generating final answer for topic: {topic}")
    results = results.dedup()
    logging.info(f"Deduplication complete, kept {len(results.results)} results")
    filtered_results = await filter_results(
        topic=topic,
        results=results,
        prompts_file=prompts_file,
        planning_model=planning_model,
        json_model=json_model,
        max_sources=max_sources,
    )
    logging.info(
        f"LLM Filtering complete, kept {len(filtered_results.results)} results"
    )

    # Generate final answer
    answer = await generate_research_answer(
        topic=topic,
        results=filtered_results,
        remove_thinking_tags=remove_thinking_tags,
        prompts_file=prompts_file,
        answer_model=answer_model,
    )

    return answer
# {{/docs-fragment research_topic}}

# {{docs-fragment main}}
@env.task(report=True)
async def main(
    topic: str = (
        "List the essential requirements for a developer-focused agent orchestration system."
    ),
    prompts_file: File | str = "/root/prompts.yaml",
    budget: int = 2,
    remove_thinking_tags: bool = True,
    max_queries: int = 3,
    answer_model: str = "together_ai/deepseek-ai/DeepSeek-V3",
    planning_model: str = "together_ai/Qwen/Qwen2.5-72B-Instruct-Turbo",
    json_model: str = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
    max_sources: int = 10,
    summarization_model: str = "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
) -> str:
    if isinstance(prompts_file, str):
        prompts_file = await File.from_local(prompts_file)

    answer = await research_topic(
        topic=topic,
        budget=budget,
        remove_thinking_tags=remove_thinking_tags,
        max_queries=max_queries,
        answer_model=answer_model,
        planning_model=planning_model,
        json_model=json_model,
        max_sources=max_sources,
        summarization_model=summarization_model,
        prompts_file=prompts_file,
    )

    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    toc_image_url = await generate_toc_image(
        yaml.safe_load(yaml_contents)["data_visualization_prompt"],
        planning_model,
        topic,
    )

    html_content = await generate_html(answer, toc_image_url)
    await flyte.report.replace.aio(html_content, do_flush=True)
    await flyte.report.flush.aio()

    return html_content
# {{/docs-fragment main}}

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/deep_research_agent/agent.py*
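The iterative loop in `research_topic` above is easier to follow without the Flyte and LLM plumbing. The sketch below is a reduced version of it. `iterative_research`, `toy_evaluate`, and `toy_search` are hypothetical names. The two callables stand in for `evaluate_research_completeness` and `search_all_queries`, and results are modeled as plain lists of strings instead of `DeepResearchResults`.

```python
from typing import Callable


def iterative_research(
    topic: str,
    initial_queries: list[str],
    evaluate: Callable[[str, list[str], list[str]], list[str]],
    search: Callable[[list[str]], list[str]],
    budget: int = 3,
    max_queries: int = 5,
) -> tuple[list[str], list[str]]:
    """Budget-bounded research loop modeled on research_topic (sketch)."""
    # The topic is always searched first, followed by up to max_queries - 1 generated queries.
    queries = [topic, *initial_queries[: max_queries - 1]]
    all_queries = list(queries)
    results = search(queries)

    for _ in range(budget):
        # Ask for follow-up queries, ignoring empty strings.
        additional = [q for q in evaluate(topic, results, all_queries) if q]
        if not additional:
            break  # the evaluator judged the research complete
        additional = additional[:max_queries]
        results = results + search(additional)
        all_queries.extend(additional)

    return results, all_queries


# Hypothetical stand-ins: the evaluator asks for one follow-up, then stops.
def toy_evaluate(topic: str, results: list[str], queries: list[str]) -> list[str]:
    return ["follow-up"] if len(queries) < 3 else []


def toy_search(queries: list[str]) -> list[str]:
    return [f"result for {q}" for q in queries]


results, all_q = iterative_research(
    "topic", ["q1", "q2", "q3"], toy_evaluate, toy_search, budget=3, max_queries=2
)
print(all_q)  # ['topic', 'q1', 'follow-up']
```

The loop exits as soon as the evaluator returns no usable queries. As a result, `budget` caps the number of follow-up rounds rather than fixing it.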

## Filter results

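In this step, the `filter_results` task asks the planning model to evaluate the deduplicated search results and rank them by relevance to the topic. The JSON model then parses that ranking into a list of source numbers. The task keeps at most `max_sources` of those results for the final synthesis; a value of `-1` disables the cap.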

```
# {{docs-fragment filter_results}}
@env.task
async def filter_results(
    topic: str,
    results: DeepResearchResults,
    prompts_file: File,
    planning_model: str,
    json_model: str,
    max_sources: int,
) -> DeepResearchResults:
    """Rank the search results by relevance to the topic and keep the top sources"""

    # Format the search results for the LLM, without the raw content
    formatted_results = str(results)

    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)
    FILTER_PROMPT = prompts["filter_prompt"]

    logging.info("\nFilter response: ")
    filter_response = ""
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=FILTER_PROMPT,
        message=(
            f"<Research Topic>{topic}</Research Topic>\n\n"
            f"<Current Search Results>{formatted_results}</Current Search Results>"
        ),
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        filter_response += chunk
        print(chunk, end="", flush=True)

    logging.info(f"Filter response: {filter_response}")

    FILTER_PARSING_PROMPT = prompts["filter_parsing_prompt"]

    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=FILTER_PARSING_PROMPT,
        message=f"Filter response to be parsed: {filter_response}",
        response_format={
            "type": "json_object",
            "schema": SourceList.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk

    sources = json.loads(response_json)["sources"]

    logging.info(f"Filtered sources: {sources}")

    if max_sources != -1:
        sources = sources[:max_sources]

    # Map the 1-based source numbers back to results, skipping 0, negative,
    # and out-of-range values so they cannot wrap around the list
    filtered_results = [
        results.results[i - 1] for i in sources if 0 < i <= len(results.results)
    ]

    return DeepResearchResults(results=filtered_results)
# {{/docs-fragment filter_results}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/deep_research_agent/agent.py*
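The last step of `filter_results` maps the source numbers returned by the JSON model back to entries in `results.results`. The listing's `i - 1` indicates these numbers are 1-based. The standalone sketch below uses plain strings in place of `DeepResearchResult` objects, and `select_sources` is a hypothetical name. Its bounds check also rejects `0` and negative numbers. A test of `i - 1 < len(...)` alone would let those values through as negative, from-the-end indices.

```python
def select_sources(results: list[str], sources: list[int], max_sources: int) -> list[str]:
    """Keep the results named by 1-based `sources`, capped at `max_sources` (-1 = no cap)."""
    if max_sources != -1:
        sources = sources[:max_sources]
    # Convert each 1-based number to a 0-based position; skip out-of-range values
    # so a stray 0 or negative number cannot wrap around to the end of the list.
    return [results[i - 1] for i in sources if 0 < i <= len(results)]


results = ["a", "b", "c", "d"]
print(select_sources(results, [3, 1, 9, 0], max_sources=-1))  # ['c', 'a']
print(select_sources(results, [3, 1, 2], max_sources=2))      # ['c', 'a']
```

The sketch preserves the order of `sources`, so the answer step receives the sources in the model's ranked order.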

## Generate the final answer

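Finally, the `generate_research_answer` task sends the filtered results to the answer model, which synthesizes them into a detailed research report. The task then strips any `<think>` blocks (when `remove_thinking_tags` is set) and any Markdown code fence wrapping the response. It returns the report, which `main` renders as HTML in the run report.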
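You can try the output cleanup from `generate_research_answer` in isolation. The sketch below reproduces its fence-stripping logic, with one change: it replaces the listing's `find`-based `_remove_thinking_tags` loop with a non-greedy regular expression. For well-formed tags the two give the same result. Unlike the loop, the regex version cannot spin forever when a stray `</think>` appears before `<think>`. `clean_answer` is a hypothetical name, not part of the tutorial code.

```python
import re

# Non-greedy match so each <think> pairs with the nearest following </think>.
_THINK_BLOCK = re.compile(r"<think>.*?</think>", re.DOTALL)


def clean_answer(answer: str, remove_thinking_tags: bool = True) -> str:
    """Post-process model output the way generate_research_answer does (sketch)."""
    if remove_thinking_tags:
        answer = _THINK_BLOCK.sub("", answer)

    # If the answer opens with a Markdown code fence, drop the opening fence line...
    if answer.lstrip().startswith("```"):
        first_linebreak = answer.find("\n", answer.find("```"))
        if first_linebreak != -1:
            answer = answer[first_linebreak + 1 :]
        # ...and the closing fence, if there is one.
        if answer.rstrip().endswith("```"):
            answer = answer.rstrip()[:-3].rstrip()

    return answer.strip()


raw = "<think>\nplanning...\n</think>\n```markdown\n# Report\nBody text\n```\n"
print(repr(clean_answer(raw)))  # '# Report\nBody text'
```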

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "pydantic==2.11.5",
#    "litellm==1.72.2",
#    "tavily-python==0.7.5",
#    "together==1.5.24",
#    "markdown==3.8.2",
#    "pymdown-extensions==10.16.1",
# ]
# main = "main"
# params = ""
# ///

# {{docs-fragment env}}
import asyncio
import json
from pathlib import Path

import flyte
import yaml
from flyte.io._file import File
from libs.utils.data_types import (
    DeepResearchResult,
    DeepResearchResults,
    ResearchPlan,
    SourceList,
)
from libs.utils.generation import generate_html, generate_toc_image
from libs.utils.llms import asingle_shot_llm_call
from libs.utils.log import AgentLogger
from libs.utils.tavily_search import atavily_search_results

TIME_LIMIT_MULTIPLIER = 5
MAX_COMPLETION_TOKENS = 4096

logging = AgentLogger("together.open_deep_research")

env = flyte.TaskEnvironment(
    name="deep-researcher",
    secrets=[
        flyte.Secret(key="together_api_key", as_env_var="TOGETHER_API_KEY"),
        flyte.Secret(key="tavily_api_key", as_env_var="TAVILY_API_KEY"),
    ],
    image=flyte.Image.from_uv_script(__file__, name="deep-research-agent", pre=True)
    .with_apt_packages("pandoc", "texlive-xetex")
    .with_source_file(Path("prompts.yaml"), "/root"),
    resources=flyte.Resources(cpu=1),
)
# {{/docs-fragment env}}

# {{docs-fragment generate_research_queries}}
@env.task
async def generate_research_queries(
    topic: str,
    planning_model: str,
    json_model: str,
    prompts_file: File,
) -> list[str]:
    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)
    PLANNING_PROMPT = prompts["planning_prompt"]

    plan = ""
    logging.info(f"\n\nGenerated deep research plan for topic: {topic}\n\nPlan:")
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=PLANNING_PROMPT,
        message=f"Research Topic: {topic}",
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        plan += chunk
        print(chunk, end="", flush=True)

    SEARCH_PROMPT = prompts["plan_parsing_prompt"]

    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=SEARCH_PROMPT,
        message=f"Plan to be parsed: {plan}",
        response_format={
            "type": "json_object",
            "schema": ResearchPlan.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk

    plan = json.loads(response_json)
    return plan["queries"]
# {{/docs-fragment generate_research_queries}}

async def _summarize_content_async(
    raw_content: str,
    query: str,
    prompt: str,
    summarization_model: str,
) -> str:
    """Summarize content asynchronously using the LLM"""
    logging.info("Summarizing content asynchronously using the LLM")

    result = ""
    async for chunk in asingle_shot_llm_call(
        model=summarization_model,
        system_prompt=prompt,
        message=f"<Raw Content>{raw_content}</Raw Content>\n\n<Research Topic>{query}</Research Topic>",
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        result += chunk
    return result

# {{docs-fragment search_and_summarize}}
@env.task
async def search_and_summarize(
    query: str,
    prompts_file: File,
    summarization_model: str,
) -> DeepResearchResults:
    """Perform search for a single query"""

    if len(query) > 400:
        # NOTE: we are truncating the query to 400 characters to avoid Tavily Search issues
        query = query[:400]
        logging.info(f"Truncated query to 400 characters: {query}")

    response = await atavily_search_results(query)

    logging.info("Tavily Search Called.")

    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)
    RAW_CONTENT_SUMMARIZER_PROMPT = prompts["raw_content_summarizer_prompt"]

    with flyte.group("summarize-content"):
        # Create tasks for summarization
        summarization_tasks = []
        result_info = []
        for result in response.results:
            if result.raw_content is None:
                continue

            task = _summarize_content_async(
                result.raw_content,
                query,
                RAW_CONTENT_SUMMARIZER_PROMPT,
                summarization_model,
            )
            summarization_tasks.append(task)
            result_info.append(result)

        # Use return_exceptions=True to prevent exceptions from propagating
        summarized_contents = await asyncio.gather(
            *summarization_tasks, return_exceptions=True
        )

    # Pair each summary with its source result before dropping failures, so a
    # failed summarization cannot shift later summaries onto the wrong result
    formatted_results = []
    for result, summarized_content in zip(result_info, summarized_contents):
        if isinstance(summarized_content, BaseException):
            continue
        formatted_results.append(
            DeepResearchResult(
                title=result.title,
                link=result.link,
                content=result.content,
                raw_content=result.raw_content,
                filtered_raw_content=summarized_content,
            )
        )
    return DeepResearchResults(results=formatted_results)
# {{/docs-fragment search_and_summarize}}

@env.task
async def search_all_queries(
    queries: list[str], summarization_model: str, prompts_file: File
) -> DeepResearchResults:
    """Execute searches for all queries in parallel"""
    tasks = [
        search_and_summarize(query, prompts_file, summarization_model)
        for query in queries
    ]

    # Run all searches concurrently; with no queries there is nothing to combine
    results_list = []
    if tasks:
        results_list = await asyncio.gather(*tasks)

    # Combine all results
    combined_results = DeepResearchResults(results=[])
    for results in results_list:
        combined_results = combined_results + results

    return combined_results

# {{docs-fragment evaluate_research_completeness}}
@env.task
async def evaluate_research_completeness(
    topic: str,
    results: DeepResearchResults,
    queries: list[str],
    prompts_file: File,
    planning_model: str,
    json_model: str,
) -> list[str]:
    """
    Evaluate if the current search results are sufficient or if more research is needed.
    Returns an empty list if research is complete, or a list of additional queries if more research is needed.
    """

    # Format the search results for the LLM
    formatted_results = str(results)

    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)

    EVALUATION_PROMPT = prompts["evaluation_prompt"]

    logging.info("\nEvaluation: ")
    evaluation = ""
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=EVALUATION_PROMPT,
        message=(
            f"<Research Topic>{topic}</Research Topic>\n\n"
            f"<Search Queries Used>{queries}</Search Queries Used>\n\n"
            f"<Current Search Results>{formatted_results}</Current Search Results>"
        ),
        response_format=None,
        max_completion_tokens=None,
    ):
        evaluation += chunk
        print(chunk, end="", flush=True)

    EVALUATION_PARSING_PROMPT = prompts["evaluation_parsing_prompt"]

    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=EVALUATION_PARSING_PROMPT,
        message=f"Evaluation to be parsed: {evaluation}",
        response_format={
            "type": "json_object",
            "schema": ResearchPlan.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk

    evaluation = json.loads(response_json)
    return evaluation["queries"]
# {{/docs-fragment evaluate_research_completeness}}

# {{docs-fragment filter_results}}
@env.task
async def filter_results(
    topic: str,
    results: DeepResearchResults,
    prompts_file: File,
    planning_model: str,
    json_model: str,
    max_sources: int,
) -> DeepResearchResults:
    """Filter the search results based on the research plan"""

    # Format the search results for the LLM, without the raw content
    formatted_results = str(results)

    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)
    FILTER_PROMPT = prompts["filter_prompt"]

    logging.info("\nFilter response: ")
    filter_response = ""
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=FILTER_PROMPT,
        message=(
            f"<Research Topic>{topic}</Research Topic>\n\n"
            f"<Current Search Results>{formatted_results}</Current Search Results>"
        ),
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        filter_response += chunk
        print(chunk, end="", flush=True)

    logging.info(f"Filter response: {filter_response}")

    FILTER_PARSING_PROMPT = prompts["filter_parsing_prompt"]

    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=FILTER_PARSING_PROMPT,
        message=f"Filter response to be parsed: {filter_response}",
        response_format={
            "type": "json_object",
            "schema": SourceList.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk

    sources = json.loads(response_json)["sources"]

    logging.info(f"Filtered sources: {sources}")

    if max_sources != -1:
        sources = sources[:max_sources]

    # Filter the results based on the source list; source numbers are 1-based,
    # so keep only those in 1..len(results.results)
    filtered_results = [
        results.results[i - 1] for i in sources if 1 <= i <= len(results.results)
    ]

    return DeepResearchResults(results=filtered_results)
# {{/docs-fragment filter_results}}

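def _remove_thinking_tags(answer: str) -> str:
    """Remove every <think>...</think> block from the answer"""
    while (start := answer.find("<think>")) != -1:
        # Search for the closing tag only after the opening tag, so a stray
        # "</think>" earlier in the text cannot cause an infinite loop
        end = answer.find("</think>", start)
        if end == -1:
            break
        answer = answer[:start] + answer[end + len("</think>") :]
    return answer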

# {{docs-fragment generate_research_answer}}
@env.task
async def generate_research_answer(
    topic: str,
    results: DeepResearchResults,
    remove_thinking_tags: bool,
    prompts_file: File,
    answer_model: str,
) -> str:
    """
    Generate a comprehensive answer to the research topic based on the search results.
    Returns a detailed response that synthesizes information from all search results.
    """

    formatted_results = str(results)
    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)
    ANSWER_PROMPT = prompts["answer_prompt"]

    answer = ""
    async for chunk in asingle_shot_llm_call(
        model=answer_model,
        system_prompt=ANSWER_PROMPT,
        message=f"Research Topic: {topic}\n\nSearch Results:\n{formatted_results}",
        response_format=None,
        # NOTE: This is the max_token parameter for the LLM call on Together AI,
        # may need to be changed for other providers
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        answer += chunk

    # this is just to avoid typing complaints
    if answer is None or not isinstance(answer, str):
        logging.error("No answer generated")
        return "No answer generated"

    if remove_thinking_tags:
        # Remove content within <think> tags
        answer = _remove_thinking_tags(answer)

    # Remove markdown code block markers if they exist at the beginning
    if answer.lstrip().startswith("```"):
        # Find the first line break after the opening backticks
        first_linebreak = answer.find("\n", answer.find("```"))
        if first_linebreak != -1:
            # Remove everything up to and including the first line break
            answer = answer[first_linebreak + 1 :]

        # Remove closing code block if it exists
        if answer.rstrip().endswith("```"):
            answer = answer.rstrip()[:-3].rstrip()

    return answer.strip()
# {{/docs-fragment generate_research_answer}}

# {{docs-fragment research_topic}}
@env.task(retries=flyte.RetryStrategy(count=3, backoff=10, backoff_factor=2))
async def research_topic(
    topic: str,
    budget: int = 3,
    remove_thinking_tags: bool = True,
    max_queries: int = 5,
    answer_model: str = "together_ai/deepseek-ai/DeepSeek-V3",
    planning_model: str = "together_ai/Qwen/Qwen2.5-72B-Instruct-Turbo",
    json_model: str = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
    max_sources: int = 40,
    summarization_model: str = "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
    prompts_file: File | str = "prompts.yaml",
) -> str:
    """Main method to conduct research on a topic. Will be used for weave evals."""
    if isinstance(prompts_file, str):
        prompts_file = await File.from_local(prompts_file)

    # Step 1: Generate initial queries
    queries = await generate_research_queries(
        topic=topic,
        planning_model=planning_model,
        json_model=json_model,
        prompts_file=prompts_file,
    )
    queries = [topic, *queries[: max_queries - 1]]
    all_queries = queries.copy()
    logging.info(f"Initial queries: {queries}")

    if len(queries) == 0:
        logging.error("No initial queries generated")
        return "No initial queries generated"

    # Step 2: Perform initial search
    results = await search_all_queries(queries, summarization_model, prompts_file)
    logging.info(f"Initial search complete, found {len(results.results)} results")

    # Step 3: Conduct iterative research within budget
    for iteration in range(budget):
        with flyte.group(f"eval_iteration_{iteration}"):
            # Evaluate if more research is needed
            additional_queries = await evaluate_research_completeness(
                topic=topic,
                results=results,
                queries=all_queries,
                prompts_file=prompts_file,
                planning_model=planning_model,
                json_model=json_model,
            )

            # Filter out empty strings and check if any queries remain
            additional_queries = [q for q in additional_queries if q]
            if not additional_queries:
                logging.info("No need for additional research")
                break

            # for debugging purposes we limit the number of queries
            additional_queries = additional_queries[:max_queries]
            logging.info(f"Additional queries: {additional_queries}")

            # Expand research with new queries
            new_results = await search_all_queries(
                additional_queries, summarization_model, prompts_file
            )
            logging.info(
                f"Follow-up search complete, found {len(new_results.results)} results"
            )

            results = results + new_results
            all_queries.extend(additional_queries)

    # Step 4: Generate final answer
    logging.info(f"Generating final answer for topic: {topic}")
    results = results.dedup()
    logging.info(f"Deduplication complete, kept {len(results.results)} results")
    filtered_results = await filter_results(
        topic=topic,
        results=results,
        prompts_file=prompts_file,
        planning_model=planning_model,
        json_model=json_model,
        max_sources=max_sources,
    )
    logging.info(
        f"LLM Filtering complete, kept {len(filtered_results.results)} results"
    )

    # Generate final answer
    answer = await generate_research_answer(
        topic=topic,
        results=filtered_results,
        remove_thinking_tags=remove_thinking_tags,
        prompts_file=prompts_file,
        answer_model=answer_model,
    )

    return answer
# {{/docs-fragment research_topic}}

# {{docs-fragment main}}
@env.task(report=True)
async def main(
    topic: str = (
        "List the essential requirements for a developer-focused agent orchestration system."
    ),
    prompts_file: File | str = "/root/prompts.yaml",
    budget: int = 2,
    remove_thinking_tags: bool = True,
    max_queries: int = 3,
    answer_model: str = "together_ai/deepseek-ai/DeepSeek-V3",
    planning_model: str = "together_ai/Qwen/Qwen2.5-72B-Instruct-Turbo",
    json_model: str = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
    max_sources: int = 10,
    summarization_model: str = "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
) -> str:
    if isinstance(prompts_file, str):
        prompts_file = await File.from_local(prompts_file)

    answer = await research_topic(
        topic=topic,
        budget=budget,
        remove_thinking_tags=remove_thinking_tags,
        max_queries=max_queries,
        answer_model=answer_model,
        planning_model=planning_model,
        json_model=json_model,
        max_sources=max_sources,
        summarization_model=summarization_model,
        prompts_file=prompts_file,
    )

    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    toc_image_url = await generate_toc_image(
        yaml.safe_load(yaml_contents)["data_visualization_prompt"],
        planning_model,
        topic,
    )

    html_content = await generate_html(answer, toc_image_url)
    await flyte.report.replace.aio(html_content, do_flush=True)
    await flyte.report.flush.aio()

    return html_content
# {{/docs-fragment main}}

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/deep_research_agent/agent.py*
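`filter_results` asks the JSON model for a list of source numbers, optionally truncates it to `max_sources` (where `-1` means no limit), and maps each number `i` back to `results.results[i - 1]`. The following minimal sketch isolates that mapping on plain strings. The helper name `select_sources` is hypothetical, and the sketch assumes, as the `i - 1` in the code implies, that the numbers are 1-based. It keeps only numbers in `1..len(items)`, because a `0` would otherwise become index `-1` and silently select the last result.

```python
def select_sources(items: list[str], sources: list[int], max_sources: int = -1) -> list[str]:
    """Hypothetical helper: map 1-based source numbers back onto a result list.

    max_sources == -1 means "no cap", mirroring filter_results.
    """
    # Truncate first, in the same order as filter_results
    if max_sources != -1:
        sources = sources[:max_sources]
    # Keep only numbers that point at a real result (1..len(items))
    return [items[i - 1] for i in sources if 1 <= i <= len(items)]


results = ["result A", "result B", "result C"]
print(select_sources(results, [3, 1, 7, 0]))           # ['result C', 'result A']
print(select_sources(results, [2, 3], max_sources=1))  # ['result B']
```

Truncating before validating mirrors the original order, so an out-of-range number still uses up one of the `max_sources` slots.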

## Orchestration

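Next, we define a `research_topic` task that orchestrates the entire deep research workflow. It runs the core stages in sequence: generating the initial research queries, searching and summarizing the results in parallel, repeatedly evaluating whether the results are complete and issuing follow-up searches (up to `budget` iterations), deduplicating and filtering the sources, and synthesizing the final answer.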

```
@env.task(retries=flyte.RetryStrategy(count=3, backoff=10, backoff_factor=2))
async def research_topic(
    topic: str,
    budget: int = 3,
    remove_thinking_tags: bool = True,
    max_queries: int = 5,
    answer_model: str = "together_ai/deepseek-ai/DeepSeek-V3",
    planning_model: str = "together_ai/Qwen/Qwen2.5-72B-Instruct-Turbo",
    json_model: str = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
    max_sources: int = 40,
    summarization_model: str = "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
    prompts_file: File | str = "prompts.yaml",
) -> str:
    """Main method to conduct research on a topic. Will be used for weave evals."""
    if isinstance(prompts_file, str):
        prompts_file = await File.from_local(prompts_file)

    # Step 1: Generate initial queries
    queries = await generate_research_queries(
        topic=topic,
        planning_model=planning_model,
        json_model=json_model,
        prompts_file=prompts_file,
    )
    queries = [topic, *queries[: max_queries - 1]]
    all_queries = queries.copy()
    logging.info(f"Initial queries: {queries}")

    if len(queries) == 0:
        logging.error("No initial queries generated")
        return "No initial queries generated"

    # Step 2: Perform initial search
    results = await search_all_queries(queries, summarization_model, prompts_file)
    logging.info(f"Initial search complete, found {len(results.results)} results")

    # Step 3: Conduct iterative research within budget
    for iteration in range(budget):
        with flyte.group(f"eval_iteration_{iteration}"):
            # Evaluate if more research is needed
            additional_queries = await evaluate_research_completeness(
                topic=topic,
                results=results,
                queries=all_queries,
                prompts_file=prompts_file,
                planning_model=planning_model,
                json_model=json_model,
            )

            # Filter out empty strings and check if any queries remain
            additional_queries = [q for q in additional_queries if q]
            if not additional_queries:
                logging.info("No need for additional research")
                break

            # for debugging purposes we limit the number of queries
            additional_queries = additional_queries[:max_queries]
            logging.info(f"Additional queries: {additional_queries}")

            # Expand research with new queries
            new_results = await search_all_queries(
                additional_queries, summarization_model, prompts_file
            )
            logging.info(
                f"Follow-up search complete, found {len(new_results.results)} results"
            )

            results = results + new_results
            all_queries.extend(additional_queries)

    # Step 4: Generate final answer
    logging.info(f"Generating final answer for topic: {topic}")
    results = results.dedup()
    logging.info(f"Deduplication complete, kept {len(results.results)} results")
    filtered_results = await filter_results(
        topic=topic,
        results=results,
        prompts_file=prompts_file,
        planning_model=planning_model,
        json_model=json_model,
        max_sources=max_sources,
    )
    logging.info(
        f"LLM Filtering complete, kept {len(filtered_results.results)} results"
    )

    # Generate final answer
    answer = await generate_research_answer(
        topic=topic,
        results=filtered_results,
        remove_thinking_tags=remove_thinking_tags,
        prompts_file=prompts_file,
        answer_model=answer_model,
    )

    return answer
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/deep_research_agent/agent.py*
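Stripped of Flyte and the LLM calls, `research_topic` is a budgeted refinement loop: search once, then up to `budget` times ask an evaluator for follow-up queries, stopping early when it returns none. The sketch below models that control flow with plain callables so it can run locally. `iterative_research`, `fake_search`, and `fake_evaluate` are hypothetical stand-ins, and the order-preserving `dict.fromkeys` dedup only approximates `DeepResearchResults.dedup()`, whose exact rule isn't shown here.

```python
from typing import Callable


def iterative_research(
    topic: str,
    initial_queries: list[str],
    search: Callable[[list[str]], list[str]],
    evaluate: Callable[[list[str], list[str]], list[str]],
    budget: int = 3,
    max_queries: int = 5,
) -> tuple[list[str], list[str]]:
    """Hypothetical sketch of research_topic's search/evaluate control flow."""
    # The topic itself is always the first query; cap the total at max_queries
    queries = [topic, *initial_queries[: max_queries - 1]]
    all_queries = list(queries)
    results = search(queries)

    for _ in range(budget):
        # The evaluator returns follow-up queries, or an empty list when done
        follow_ups = [q for q in evaluate(results, all_queries) if q][:max_queries]
        if not follow_ups:
            break
        results = results + search(follow_ups)
        all_queries.extend(follow_ups)

    # Order-preserving dedup, standing in for DeepResearchResults.dedup()
    return list(dict.fromkeys(results)), all_queries


def fake_search(qs: list[str]) -> list[str]:
    return [f"hit:{q}" for q in qs]


def fake_evaluate(results: list[str], queries: list[str]) -> list[str]:
    # Never satisfied, so only the budget stops the loop
    return ["more"]


results, queries = iterative_research(
    "agents", ["q1", "q2"], fake_search, fake_evaluate, budget=2, max_queries=2
)
print(results)  # ['hit:agents', 'hit:q1', 'hit:more']
print(queries)  # ['agents', 'q1', 'more', 'more']
```

Because the loop is bounded by `budget`, an evaluator that never declares the research complete still terminates; the early `break` is what saves LLM and search calls when the evaluator is satisfied.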

The `main` task wraps this pipeline and adds a final step that renders the answer as an HTML report, along with an image produced by `generate_toc_image`, and publishes it through `flyte.report`.
It is also the workflow's entry point, so every configuration parameter can be passed in here, including which LLM to use at each stage.
This lets us mix and match models for planning, JSON parsing, summarization, and final synthesis, trading off cost against quality one stage at a time.
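One way to manage these per-stage choices is to collect them in a single object whose field names match `main`'s keyword arguments. `StageModels` below is a hypothetical helper, not part of the tutorial; its defaults are copied from `main`, and the comments record which stages each argument drives in `main` and `research_topic`.

```python
from dataclasses import asdict, dataclass, replace


@dataclass(frozen=True)
class StageModels:
    """Hypothetical container for main's per-stage model arguments."""

    # Planning, completeness evaluation, source filtering, and the report image
    planning_model: str = "together_ai/Qwen/Qwen2.5-72B-Instruct-Turbo"
    # Converts free-form plans, evaluations, and filter output into JSON
    json_model: str = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo"
    # Summarizes each fetched page, so it runs once per search result
    summarization_model: str = "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo"
    # Writes the final answer from the filtered sources
    answer_model: str = "together_ai/deepseek-ai/DeepSeek-V3"


DEFAULTS = StageModels()

# Example: reuse the summarization model for the final answer to cut cost,
# leaving every other stage unchanged
budget_run = replace(DEFAULTS, answer_model=DEFAULTS.summarization_model)

kwargs = asdict(budget_run)
print(kwargs["answer_model"])  # together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo
```

The resulting dictionary could then be passed as keyword arguments when launching the workflow, for example `flyte.run(main, **kwargs)`; this assumes `flyte.run` forwards keyword arguments to the task's inputs, which the tutorial's plain `flyte.run(main)` call does not show.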

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "pydantic==2.11.5",
#    "litellm==1.72.2",
#    "tavily-python==0.7.5",
#    "together==1.5.24",
#    "markdown==3.8.2",
#    "pymdown-extensions==10.16.1",
# ]
# main = "main"
# params = ""
# ///

# {{docs-fragment env}}
import asyncio
import json
from pathlib import Path

import flyte
import yaml
from flyte.io import File
from libs.utils.data_types import (
    DeepResearchResult,
    DeepResearchResults,
    ResearchPlan,
    SourceList,
)
from libs.utils.generation import generate_html, generate_toc_image
from libs.utils.llms import asingle_shot_llm_call
from libs.utils.log import AgentLogger
from libs.utils.tavily_search import atavily_search_results

TIME_LIMIT_MULTIPLIER = 5
MAX_COMPLETION_TOKENS = 4096

logging = AgentLogger("together.open_deep_research")

env = flyte.TaskEnvironment(
    name="deep-researcher",
    secrets=[
        flyte.Secret(key="together_api_key", as_env_var="TOGETHER_API_KEY"),
        flyte.Secret(key="tavily_api_key", as_env_var="TAVILY_API_KEY"),
    ],
    image=flyte.Image.from_uv_script(__file__, name="deep-research-agent", pre=True)
    .with_apt_packages("pandoc", "texlive-xetex")
    .with_source_file(Path("prompts.yaml"), "/root"),
    resources=flyte.Resources(cpu=1),
)
# {{/docs-fragment env}}

# {{docs-fragment generate_research_queries}}
@env.task
async def generate_research_queries(
    topic: str,
    planning_model: str,
    json_model: str,
    prompts_file: File,
) -> list[str]:
    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)
    PLANNING_PROMPT = prompts["planning_prompt"]

    plan = ""
    logging.info(f"\n\nGenerated deep research plan for topic: {topic}\n\nPlan:")
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=PLANNING_PROMPT,
        message=f"Research Topic: {topic}",
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        plan += chunk
        print(chunk, end="", flush=True)

    SEARCH_PROMPT = prompts["plan_parsing_prompt"]

    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=SEARCH_PROMPT,
        message=f"Plan to be parsed: {plan}",
        response_format={
            "type": "json_object",
            "schema": ResearchPlan.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk

    plan = json.loads(response_json)
    return plan["queries"]
# {{/docs-fragment generate_research_queries}}

async def _summarize_content_async(
    raw_content: str,
    query: str,
    prompt: str,
    summarization_model: str,
) -> str:
    """Summarize content asynchronously using the LLM"""
    logging.info("Summarizing content asynchronously using the LLM")

    result = ""
    async for chunk in asingle_shot_llm_call(
        model=summarization_model,
        system_prompt=prompt,
        message=f"<Raw Content>{raw_content}</Raw Content>\n\n<Research Topic>{query}</Research Topic>",
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        result += chunk
    return result

# {{docs-fragment search_and_summarize}}
@env.task
async def search_and_summarize(
    query: str,
    prompts_file: File,
    summarization_model: str,
) -> DeepResearchResults:
    """Perform search for a single query"""

    if len(query) > 400:
        # NOTE: we are truncating the query to 400 characters to avoid Tavily Search issues
        query = query[:400]
        logging.info(f"Truncated query to 400 characters: {query}")

    response = await atavily_search_results(query)

    logging.info("Tavily Search Called.")

    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)
    RAW_CONTENT_SUMMARIZER_PROMPT = prompts["raw_content_summarizer_prompt"]

    with flyte.group("summarize-content"):
        # Create tasks for summarization
        summarization_tasks = []
        result_info = []
        for result in response.results:
            if result.raw_content is None:
                continue

            task = _summarize_content_async(
                result.raw_content,
                query,
                RAW_CONTENT_SUMMARIZER_PROMPT,
                summarization_model,
            )
            summarization_tasks.append(task)
            result_info.append(result)

        # Use return_exceptions=True to prevent exceptions from propagating
        summarized_contents = await asyncio.gather(
            *summarization_tasks, return_exceptions=True
        )

    # Pair each summary with its source result *before* dropping failures, so a
    # failed summarization cannot shift later summaries onto the wrong result
    formatted_results = []
    for result, summarized_content in zip(result_info, summarized_contents):
        if isinstance(summarized_content, Exception):
            continue
        formatted_results.append(
            DeepResearchResult(
                title=result.title,
                link=result.link,
                content=result.content,
                raw_content=result.raw_content,
                filtered_raw_content=summarized_content,
            )
        )
    return DeepResearchResults(results=formatted_results)
# {{/docs-fragment search_and_summarize}}

@env.task
async def search_all_queries(
    queries: list[str], summarization_model: str, prompts_file: File
) -> DeepResearchResults:
    """Execute searches for all queries in parallel"""
    tasks = [
        search_and_summarize(query, prompts_file, summarization_model)
        for query in queries
    ]

    # Guard against an empty query list so results_list is always defined
    results_list = await asyncio.gather(*tasks) if tasks else []

    # Combine all results
    combined_results = DeepResearchResults(results=[])
    for results in results_list:
        combined_results = combined_results + results

    return combined_results

# {{docs-fragment evaluate_research_completeness}}
@env.task
async def evaluate_research_completeness(
    topic: str,
    results: DeepResearchResults,
    queries: list[str],
    prompts_file: File,
    planning_model: str,
    json_model: str,
) -> list[str]:
    """
    Evaluate if the current search results are sufficient or if more research is needed.
    Returns an empty list if research is complete, or a list of additional queries if more research is needed.
    """

    # Format the search results for the LLM
    formatted_results = str(results)

    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)

    EVALUATION_PROMPT = prompts["evaluation_prompt"]

    logging.info("\nEvaluation: ")
    evaluation = ""
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=EVALUATION_PROMPT,
        message=(
            f"<Research Topic>{topic}</Research Topic>\n\n"
            f"<Search Queries Used>{queries}</Search Queries Used>\n\n"
            f"<Current Search Results>{formatted_results}</Current Search Results>"
        ),
        response_format=None,
        max_completion_tokens=None,
    ):
        evaluation += chunk
        print(chunk, end="", flush=True)

    EVALUATION_PARSING_PROMPT = prompts["evaluation_parsing_prompt"]

    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=EVALUATION_PARSING_PROMPT,
        message=f"Evaluation to be parsed: {evaluation}",
        response_format={
            "type": "json_object",
            "schema": ResearchPlan.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk

    evaluation = json.loads(response_json)
    return evaluation["queries"]
# {{/docs-fragment evaluate_research_completeness}}

# {{docs-fragment filter_results}}
@env.task
async def filter_results(
    topic: str,
    results: DeepResearchResults,
    prompts_file: File,
    planning_model: str,
    json_model: str,
    max_sources: int,
) -> DeepResearchResults:
    """Filter the search results based on the research plan"""

    # Format the search results for the LLM, without the raw content
    formatted_results = str(results)

    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)
    FILTER_PROMPT = prompts["filter_prompt"]

    logging.info("\nFilter response: ")
    filter_response = ""
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=FILTER_PROMPT,
        message=(
            f"<Research Topic>{topic}</Research Topic>\n\n"
            f"<Current Search Results>{formatted_results}</Current Search Results>"
        ),
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        filter_response += chunk
        print(chunk, end="", flush=True)

    logging.info(f"Filter response: {filter_response}")

    FILTER_PARSING_PROMPT = prompts["filter_parsing_prompt"]

    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=FILTER_PARSING_PROMPT,
        message=f"Filter response to be parsed: {filter_response}",
        response_format={
            "type": "json_object",
            "schema": SourceList.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk

    sources = json.loads(response_json)["sources"]

    logging.info(f"Filtered sources: {sources}")

    if max_sources != -1:
        sources = sources[:max_sources]

    # Filter the results based on the source list (source numbers are 1-based;
    # skip any out-of-range numbers the LLM may return)
    filtered_results = [
        results.results[i - 1] for i in sources if 1 <= i <= len(results.results)
    ]

    return DeepResearchResults(results=filtered_results)
# {{/docs-fragment filter_results}}

def _remove_thinking_tags(answer: str) -> str:
    """Remove content within <think> tags"""
    while "<think>" in answer and "</think>" in answer:
        start = answer.find("<think>")
        end = answer.find("</think>") + len("</think>")
        answer = answer[:start] + answer[end:]
    return answer

# {{docs-fragment generate_research_answer}}
@env.task
async def generate_research_answer(
    topic: str,
    results: DeepResearchResults,
    remove_thinking_tags: bool,
    prompts_file: File,
    answer_model: str,
) -> str:
    """
    Generate a comprehensive answer to the research topic based on the search results.
    Returns a detailed response that synthesizes information from all search results.
    """

    formatted_results = str(results)
    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)
    ANSWER_PROMPT = prompts["answer_prompt"]

    answer = ""
    async for chunk in asingle_shot_llm_call(
        model=answer_model,
        system_prompt=ANSWER_PROMPT,
        message=f"Research Topic: {topic}\n\nSearch Results:\n{formatted_results}",
        response_format=None,
        # NOTE: This is the max_token parameter for the LLM call on Together AI,
        # may need to be changed for other providers
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        answer += chunk

    # this is just to avoid typing complaints
    if answer is None or not isinstance(answer, str):
        logging.error("No answer generated")
        return "No answer generated"

    if remove_thinking_tags:
        # Remove content within <think> tags
        answer = _remove_thinking_tags(answer)

    # Remove markdown code block markers if they exist at the beginning
    if answer.lstrip().startswith("```"):
        # Find the first line break after the opening backticks
        first_linebreak = answer.find("\n", answer.find("```"))
        if first_linebreak != -1:
            # Remove everything up to and including the first line break
            answer = answer[first_linebreak + 1 :]

        # Remove closing code block if it exists
        if answer.rstrip().endswith("```"):
            answer = answer.rstrip()[:-3].rstrip()

    return answer.strip()
# {{/docs-fragment generate_research_answer}}

# {{docs-fragment research_topic}}
@env.task(retries=flyte.RetryStrategy(count=3, backoff=10, backoff_factor=2))
async def research_topic(
    topic: str,
    budget: int = 3,
    remove_thinking_tags: bool = True,
    max_queries: int = 5,
    answer_model: str = "together_ai/deepseek-ai/DeepSeek-V3",
    planning_model: str = "together_ai/Qwen/Qwen2.5-72B-Instruct-Turbo",
    json_model: str = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
    max_sources: int = 40,
    summarization_model: str = "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
    prompts_file: File | str = "prompts.yaml",
) -> str:
    """Main method to conduct research on a topic. Will be used for weave evals."""
    if isinstance(prompts_file, str):
        prompts_file = await File.from_local(prompts_file)

    # Step 1: Generate initial queries
    queries = await generate_research_queries(
        topic=topic,
        planning_model=planning_model,
        json_model=json_model,
        prompts_file=prompts_file,
    )
    queries = [topic, *queries[: max_queries - 1]]
    all_queries = queries.copy()
    logging.info(f"Initial queries: {queries}")

    if len(queries) == 0:
        logging.error("No initial queries generated")
        return "No initial queries generated"

    # Step 2: Perform initial search
    results = await search_all_queries(queries, summarization_model, prompts_file)
    logging.info(f"Initial search complete, found {len(results.results)} results")

    # Step 3: Conduct iterative research within budget
    for iteration in range(budget):
        with flyte.group(f"eval_iteration_{iteration}"):
            # Evaluate if more research is needed
            additional_queries = await evaluate_research_completeness(
                topic=topic,
                results=results,
                queries=all_queries,
                prompts_file=prompts_file,
                planning_model=planning_model,
                json_model=json_model,
            )

            # Filter out empty strings and check if any queries remain
            additional_queries = [q for q in additional_queries if q]
            if not additional_queries:
                logging.info("No need for additional research")
                break

            # for debugging purposes we limit the number of queries
            additional_queries = additional_queries[:max_queries]
            logging.info(f"Additional queries: {additional_queries}")

            # Expand research with new queries
            new_results = await search_all_queries(
                additional_queries, summarization_model, prompts_file
            )
            logging.info(
                f"Follow-up search complete, found {len(new_results.results)} results"
            )

            results = results + new_results
            all_queries.extend(additional_queries)

    # Step 4: Generate final answer
    logging.info(f"Generating final answer for topic: {topic}")
    results = results.dedup()
    logging.info(f"Deduplication complete, kept {len(results.results)} results")
    filtered_results = await filter_results(
        topic=topic,
        results=results,
        prompts_file=prompts_file,
        planning_model=planning_model,
        json_model=json_model,
        max_sources=max_sources,
    )
    logging.info(
        f"LLM Filtering complete, kept {len(filtered_results.results)} results"
    )

    # Generate final answer
    answer = await generate_research_answer(
        topic=topic,
        results=filtered_results,
        remove_thinking_tags=remove_thinking_tags,
        prompts_file=prompts_file,
        answer_model=answer_model,
    )

    return answer
# {{/docs-fragment research_topic}}

# {{docs-fragment main}}
@env.task(report=True)
async def main(
    topic: str = (
        "List the essential requirements for a developer-focused agent orchestration system."
    ),
    prompts_file: File | str = "/root/prompts.yaml",
    budget: int = 2,
    remove_thinking_tags: bool = True,
    max_queries: int = 3,
    answer_model: str = "together_ai/deepseek-ai/DeepSeek-V3",
    planning_model: str = "together_ai/Qwen/Qwen2.5-72B-Instruct-Turbo",
    json_model: str = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
    max_sources: int = 10,
    summarization_model: str = "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
) -> str:
    if isinstance(prompts_file, str):
        prompts_file = await File.from_local(prompts_file)

    answer = await research_topic(
        topic=topic,
        budget=budget,
        remove_thinking_tags=remove_thinking_tags,
        max_queries=max_queries,
        answer_model=answer_model,
        planning_model=planning_model,
        json_model=json_model,
        max_sources=max_sources,
        summarization_model=summarization_model,
        prompts_file=prompts_file,
    )

    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    toc_image_url = await generate_toc_image(
        yaml.safe_load(yaml_contents)["data_visualization_prompt"],
        planning_model,
        topic,
    )

    html_content = await generate_html(answer, toc_image_url)
    await flyte.report.replace.aio(html_content, do_flush=True)
    await flyte.report.flush.aio()

    return html_content
# {{/docs-fragment main}}

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/deep_research_agent/agent.py*
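
The summarization step in `search_and_summarize` runs one coroutine per search result through `asyncio.gather(..., return_exceptions=True)`, so a single failed LLM call does not abort the whole search. The standalone sketch below isolates that pattern with a hypothetical `summarize` stand-in (no LLM, Tavily, or Flyte involved). The detail that matters is pairing each input with its outcome before discarding exceptions: filtering first and zipping afterwards would misalign later summaries with their sources.

```python
import asyncio


async def summarize(text: str) -> str:
    """Hypothetical stand-in for an LLM summarization call."""
    await asyncio.sleep(0)
    if not text:
        raise ValueError("empty input")
    return text.upper()


async def summarize_all(items: list[str]) -> list[tuple[str, str]]:
    """Summarize items concurrently, keeping each success paired with its input."""
    outcomes = await asyncio.gather(
        *(summarize(item) for item in items), return_exceptions=True
    )
    # Pair inputs with outcomes before dropping failures, so a failed call in
    # the middle cannot shift later summaries onto the wrong input.
    return [
        (item, outcome)
        for item, outcome in zip(items, outcomes)
        if not isinstance(outcome, BaseException)
    ]


pairs = asyncio.run(summarize_all(["alpha", "", "gamma"]))
print(pairs)  # [('alpha', 'ALPHA'), ('gamma', 'GAMMA')]
```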

## Run the deep research agent

First, create the required secrets:

```
flyte create secret TOGETHER_API_KEY <>
flyte create secret TAVILY_API_KEY <>
```

Run the agent:

```
uv run agent.py
```

To test it locally first, run the following commands (they use Homebrew, so they assume macOS):

```
brew install pandoc
brew install basictex # restart your terminal after install

export TOGETHER_API_KEY=<>
export TAVILY_API_KEY=<>

uv run agent.py
```

## Evaluate with Weights & Biases Weave

We use W&B Weave to evaluate the full agent pipeline. The evaluation runs as a Flyte pipeline and scores each generated answer with an LLM-as-a-judge scorer that compares it against the dataset's reference answer.

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "weave==0.51.51",
#    "datasets==3.6.0",
#    "huggingface-hub==0.32.6",
#    "litellm==1.72.2",
#    "tavily-python==0.7.5",
# ]
# ///

import os

import weave
from agent import research_topic
from datasets import load_dataset
from huggingface_hub import login
from libs.utils.log import AgentLogger
from litellm import completion

import flyte

logging = AgentLogger()

weave.init(project_name="deep-researcher")

env = flyte.TaskEnvironment(name="deep-researcher-eval")

@weave.op
def llm_as_a_judge_scoring(answer: str, output: str, question: str) -> bool:
    prompt = f"""
    Given the following question and answer, evaluate the answer against the correct answer:

    <question>
    {question}
    </question>

    <agent_answer>
    {output}
    </agent_answer>

    <correct_answer>
    {answer}
    </correct_answer>

    Note that the agent answer might be a long text containing a lot of information or it might be a short answer.

    You should read the entire text and think if the agent answers the question somewhere
    in the text. You should try to be flexible with the answer but careful.

    For example, answering with only a first name instead of the full name is fine.

    The important thing is that the answer of the agent either contains the correct answer or is equal to
    the correct answer.

    If the agent answer is correct, return

    <reasoning>
    The agent answer is correct because I can read that ....
    </reasoning>

    <answer>
    1
    </answer>

    Otherwise, return

    <reasoning>
    The agent answer is incorrect because there is ...
    </reasoning>

    <answer>
    0
    </answer>

    """

    messages = [
        {
            "role": "system",
            "content": "You are a helpful assistant that returns a number between 0 and 1.",
        },
        {"role": "user", "content": prompt},
    ]
    answer = (
        completion(
            model="together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
            messages=messages,
            max_tokens=1000,
            temperature=0.0,
        )
        .choices[0]  # type: ignore
        .message["content"]  # type: ignore
    )

    return bool(int(answer.split("<answer>")[1].split("</answer>")[0].strip()))

def authenticate_huggingface():
    """Authenticate with Hugging Face Hub using token from environment variable."""
    token = os.getenv("HUGGINGFACE_TOKEN")
    if not token:
        raise ValueError(
            "HUGGINGFACE_TOKEN environment variable not set. "
            "Please set it with your token from https://huggingface.co/settings/tokens"
        )

    try:
        login(token=token)
        print("Successfully authenticated with Hugging Face Hub")
    except Exception as e:
        raise RuntimeError(f"Failed to authenticate with Hugging Face Hub: {e!s}") from e

@env.task
async def load_questions(
    dataset_names: list[str] | None = None,
) -> list[dict[str, str]]:
    """
    Load questions from the specified Hugging Face dataset configurations.

    Args:
        dataset_names: List of dataset configurations to load
                      Options:
                          "smolagents:simpleqa",
                          "hotpotqa",
                          "simpleqa",
                          "together-search-bench"
                      If None, defaults to ["smolagents:simpleqa"]

    Returns:
        List of question-answer pairs
    """
    if dataset_names is None:
        dataset_names = ["smolagents:simpleqa"]

    all_questions = []

    # Authenticate with Hugging Face Hub (once and for all)
    authenticate_huggingface()

    for dataset_name in dataset_names:
        print(f"Loading dataset: {dataset_name}")

        try:
            if dataset_name == "together-search-bench":
                # Load Together-Search-Bench dataset
                dataset_path = "togethercomputer/together-search-bench"
                ds = load_dataset(dataset_path)
                if "test" in ds:
                    split_data = ds["test"]
                else:
                    print(f"No 'test' split found in dataset at {dataset_path}")
                    continue

                for i in range(len(split_data)):
                    item = split_data[i]
                    question_data = {
                        "question": item["question"],
                        "answer": item["answer"],
                        "dataset": item.get("dataset", "together-search-bench"),
                    }
                    all_questions.append(question_data)

                print(f"Loaded {len(split_data)} questions from together-search-bench dataset")
                continue

            elif dataset_name == "hotpotqa":
                # Load HotpotQA dataset (using distractor version for validation)
                ds = load_dataset("hotpotqa/hotpot_qa", "distractor", trust_remote_code=True)
                split_name = "validation"
            elif dataset_name == "simpleqa":
                ds = load_dataset("basicv8vc/SimpleQA")
                split_name = "test"
            else:
                # Strip "smolagents:" prefix when loading the dataset
                actual_dataset = dataset_name.split(":")[-1]
                ds = load_dataset("smolagents/benchmark-v1", actual_dataset)
                split_name = "test"

        except Exception as e:
            print(f"Failed to load dataset {dataset_name}: {e!s}")
            continue  # Skip this dataset if it fails to load

        print(f"Dataset structure for {dataset_name}: {ds}")
        print(f"Available splits: {list(ds)}")

        split_data = ds[split_name]  # type: ignore

        for i in range(len(split_data)):
            item = split_data[i]

            if dataset_name == "hotpotqa":
                # we remove questions that are easy or medium (if any) just to reduce the number of questions
                if item["level"] != "hard":
                    continue

                question_data = {
                    "question": item["question"],
                    "answer": item["answer"],
                    "dataset": dataset_name,
                }
            elif dataset_name == "simpleqa":
                # Handle SimpleQA dataset format
                question_data = {
                    "question": item["problem"],
                    "answer": item["answer"],
                    "dataset": dataset_name,
                }
            else:
                question_data = {
                    "question": item["question"],
                    "answer": item["true_answer"],
                    "dataset": dataset_name,
                }

            all_questions.append(question_data)

    print(f"Loaded {len(all_questions)} questions in total")
    return all_questions

@weave.op
async def predict(question: str):
    return await research_topic(topic=str(question))

@env.task
async def main(datasets: list[str] = ["together-search-bench"], limit: int | None = 1):
    questions = await load_questions(datasets)

    if limit is not None:
        questions = questions[:limit]
        print(f"Limited to {len(questions)} question(s)")

    evaluation = weave.Evaluation(dataset=questions, scorers=[llm_as_a_judge_scoring])
    await evaluation.evaluate(predict)

if __name__ == "__main__":
    flyte.init_from_config()
    flyte.with_runcontext(raw_data_path="data").run(main)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/deep_research_agent/weave_evals.py*
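
The one-line parser at the end of `llm_as_a_judge_scoring` raises `IndexError` if the judge omits the `<answer>` tag and `ValueError` if the tag holds something other than an integer. Below is a minimal, more defensive sketch. `parse_judge_verdict` is a hypothetical helper, not part of the source, and treating a missing or malformed verdict as "incorrect" is a design assumption.

```python
import re

# Matches a verdict block such as "<answer>\n1\n</answer>"
ANSWER_RE = re.compile(r"<answer>\s*([01])\s*</answer>")


def parse_judge_verdict(text: str) -> bool:
    """Return True if the judge's last <answer> block contains 1."""
    matches = ANSWER_RE.findall(text)
    if not matches:
        # Treat a missing or malformed verdict as a failed answer
        # instead of raising mid-evaluation.
        return False
    return matches[-1] == "1"


print(parse_judge_verdict("<reasoning>ok</reasoning>\n<answer>\n1\n</answer>"))  # True
print(parse_judge_verdict("no verdict here"))  # False
```

Taking the last match is deliberate: if the judge echoes the prompt's example blocks before giving its own verdict, the final `<answer>` is most likely the real one.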

You can run this pipeline locally as follows:

```
export HUGGINGFACE_TOKEN=<> # https://huggingface.co/settings/tokens
export WANDB_API_KEY=<> # https://wandb.ai/settings

uv run weave_evals.py
```

The script runs every task in the pipeline and logs the evaluation results to Weights & Biases.
You can also evaluate individual tasks, but this script focuses on end-to-end evaluation of the deep research workflow.

![Weave evaluations](https://raw.githubusercontent.com/unionai/unionai-docs-static/main/images/tutorials/deep-research/weave_evals.png)
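
Weave handles prediction, scoring, and aggregation itself. To make concrete what the evaluation computes, here is a dependency-free sketch, assuming a boolean scorer with the same `(answer, output, question)` parameters as `llm_as_a_judge_scoring`: one prediction per question, one verdict per prediction, and accuracy as the fraction of `True` verdicts. `evaluate`, `fake_predict`, and `contains_scorer` are hypothetical stand-ins; this is not Weave's API, and Weave's own summary may report additional metrics.

```python
import asyncio
from collections.abc import Awaitable, Callable


async def evaluate(
    rows: list[dict[str, str]],
    predict: Callable[[str], Awaitable[str]],
    score: Callable[[str, str, str], bool],
) -> float:
    """Run predict on every question concurrently and return the fraction judged correct."""
    outputs = await asyncio.gather(*(predict(r["question"]) for r in rows))
    verdicts = [
        score(r["answer"], out, r["question"]) for r, out in zip(rows, outputs)
    ]
    return sum(verdicts) / len(verdicts) if verdicts else 0.0


async def fake_predict(question: str) -> str:
    # Hypothetical stand-in for research_topic
    return "Paris is the capital." if "France" in question else "I don't know."


def contains_scorer(answer: str, output: str, question: str) -> bool:
    # Hypothetical stand-in for llm_as_a_judge_scoring: a plain substring match
    return answer.lower() in output.lower()


rows = [
    {"question": "Capital of France?", "answer": "Paris"},
    {"question": "Capital of Peru?", "answer": "Lima"},
]
accuracy = asyncio.run(evaluate(rows, fake_predict, contains_scorer))
print(accuracy)  # 0.5
```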

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/agents/langgraph-agent-research ===

# LangGraph research agent

> **📝 Note**
>
> Code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/langgraph_agent_research).

This tutorial combines [LangGraph](https://langchain-ai.github.io/langgraph/) for agentic control flow with Flyte for durable compute. A research pipeline plans sub-topics, fans out ReAct agents that search the web with [Tavily](https://tavily.com/), synthesizes findings, and loops on quality gaps until the report is good enough. Each LangGraph step dispatches to a separate Flyte task so planning, research, synthesis, and quality checks appear independently in the run UI.

Flyte provides:

- **Per-step tasks** visible in the Flyte UI while LangGraph orchestrates the graph.
- **Secrets** for Anthropic and Tavily API keys.
- **Live HTML reports** with Mermaid graph visualizations and the final synthesized report.
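
In this tutorial LangGraph drives the loop and each node calls a Flyte task. Stripped of both frameworks, the flow described above reduces to a plan, a concurrent fan-out, a synthesis, and a quality check that either stops or feeds its gaps back in. The sketch below shows only that shape in plain `asyncio`. Every function is a hypothetical stand-in, and turning gaps into the next round's research topics is an assumption here; the real routing is defined by `build_pipeline_graph`.

```python
import asyncio
from dataclasses import dataclass, field


@dataclass
class Quality:
    score: int
    gaps: list[str] = field(default_factory=list)


# Hypothetical stand-ins for the plan_topics, research_topic,
# synthesize, and quality_check Flyte tasks.
async def fake_plan(query: str, n: int) -> list[str]:
    return [f"{query} / aspect {i}" for i in range(1, n + 1)]


async def fake_research(topic: str) -> str:
    await asyncio.sleep(0)
    return f"notes on {topic}"


async def fake_synthesize(query: str, notes: list[str]) -> str:
    return f"{query}: " + "; ".join(notes)


async def fake_check(report: str, iteration: int) -> Quality:
    # Pretend the first draft has one gap and the second is good enough.
    return Quality(score=6, gaps=["costs"]) if iteration == 0 else Quality(score=9)


async def pipeline(query: str, num_topics: int = 3, max_iterations: int = 2) -> str:
    topics = await fake_plan(query, num_topics)
    notes: list[str] = []
    report = ""
    for iteration in range(max_iterations):
        # Fan out: research every pending topic concurrently.
        notes += await asyncio.gather(*(fake_research(t) for t in topics))
        report = await fake_synthesize(query, notes)
        quality = await fake_check(report, iteration)
        if not quality.gaps:
            break
        # Loop back: the gaps become the next round's topics.
        topics = quality.gaps
    return report


report = asyncio.run(pipeline("vector databases", num_topics=2))
print(report)
```

Capping the loop with `max_iterations` mirrors the tutorial's parameter of the same name: the quality check can request more research, but it cannot keep the pipeline running indefinitely.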

## Define the task environment

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.4.0",
#    "langgraph>=1.0.7",
#    "langchain-anthropic",
#    "tavily-python",
#    "markdown",
#    "pydantic",
# ]
# main = "research_pipeline"
# params = ""
# ///
import json
import os
import base64
import logging
import markdown

import flyte
import flyte.report

# {{docs-fragment env}}
main_img = flyte.Image.from_uv_script(__file__, name="langgraph-agent-research", pre=True)

env = flyte.TaskEnvironment(
    name="langgraph-agent-research",
    image=main_img,
    secrets=[
        flyte.Secret(key="internal-anthropic-api-key", as_env_var="ANTHROPIC_API_KEY"),
        flyte.Secret(key="tavily_api_key", as_env_var="TAVILY_API_KEY"),
    ],
    resources=flyte.Resources(cpu=2, memory="2Gi"),
)
# {{/docs-fragment env}}

from langchain_anthropic import ChatAnthropic
from langchain_core.messages import HumanMessage

from models import TopicReport, QualityResult, PipelineResult
from graph import build_pipeline_graph, build_research_subgraph

logging.basicConfig(level=logging.WARNING, format="%(message)s", force=True)
log = logging.getLogger(__name__)
log.setLevel(logging.INFO)
logging.getLogger("graph").setLevel(logging.INFO)
logging.getLogger("tools.search").setLevel(logging.INFO)

MODEL = "claude-3-5-haiku-latest"

def md_to_html(text: str) -> str:
    """Convert markdown to HTML for Flyte reports."""
    return markdown.markdown(text, extensions=["tables", "fenced_code"])

# ------------------------------------------------------------------
# Flyte tasks — each step is visible in the UI while running
# ------------------------------------------------------------------

@env.task(report=True)
async def plan_topics(query: str, num_topics: int = 3) -> list[str]:
    """Break a research query into focused sub-topics."""
    log.info(f"Planning {num_topics} sub-topics for: {query}")

    await flyte.report.replace.aio(
        f"<h2>Planning</h2><p>Breaking query into {num_topics} sub-topics...</p>"
    )
    await flyte.report.flush.aio()

    anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
    llm = ChatAnthropic(model=MODEL, api_key=anthropic_api_key)

    response = await llm.ainvoke(
        f"Break this research question into exactly {num_topics} focused sub-topics. "
        f"Return ONLY a JSON array of strings, nothing else.\n\nQuestion: {query}"
    )
    try:
        topics = json.loads(response.content)
    except json.JSONDecodeError:
        topics = [query]

    topics = topics[:num_topics]
    log.info(f"Sub-topics: {topics}")

    topic_html = "".join(f"<li>{t}</li>" for t in topics)
    await flyte.report.replace.aio(
        f"<h2>Planning</h2><p>Sub-topics:</p><ul>{topic_html}</ul>"
    )
    await flyte.report.flush.aio()

    return topics

@env.task(report=True)
async def research_topic(topic: str, max_searches: int = 2) -> TopicReport:
    """Run the ReAct research agent on a single sub-topic."""
    log.info(f"[Research Task] Starting: {topic}")

    anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
    tavily_api_key = os.getenv("TAVILY_API_KEY")

    await flyte.report.replace.aio(f"<h2>Researching: {topic}</h2><p>Running searches...</p>")
    await flyte.report.flush.aio()

    graph = build_research_subgraph(
        anthropic_api_key=anthropic_api_key,
        tavily_api_key=tavily_api_key,
        max_searches=max_searches,
        model=MODEL,
    )
    result = await graph.ainvoke({"messages": [HumanMessage(content=f"Research this topic: {topic}")]})
    report = result["messages"][-1].content
    log.info(f"[Research Task] Done: {topic}")

    await flyte.report.replace.aio(f"<h2>{topic}</h2>{md_to_html(report)}")
    await flyte.report.flush.aio()

    return TopicReport(topic=topic, report=report)

@env.task(report=True)
async def synthesize(query: str, results: list[TopicReport]) -> str:
    """Combine sub-topic research reports into a unified synthesis."""
    log.info(f"Synthesizing {len(results)} report(s)...")

    await flyte.report.replace.aio(
        f"<h2>Synthesis</h2><p>Combining {len(results)} reports...</p>"
    )
    await flyte.report.flush.aio()

    anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
    llm = ChatAnthropic(model=MODEL, api_key=anthropic_api_key)

    sections = "\n\n---\n\n".join(
        f"## {r.topic}\n\n{r.report}" for r in results
    )

    response = await llm.ainvoke(
        f"You have research reports on sub-topics of this question:\n\n{query}\n\n"
        f"Sub-topic reports:\n\n{sections}\n\n"
        f"Write a comprehensive report that synthesizes all findings. "
        f"Organize by theme, highlight connections between sub-topics, "
        f"and end with key takeaways."
    )
    synthesis = response.content
    log.info(f"Synthesis complete: {len(synthesis)} chars")

    await flyte.report.replace.aio(f"<h2>Synthesis</h2>{md_to_html(synthesis)}")
    await flyte.report.flush.aio()

    return synthesis

@env.task(report=True)
async def quality_check(query: str, synthesis: str) -> QualityResult:
    """Evaluate report quality and identify gaps."""
    log.info("Evaluating quality...")

    await flyte.report.replace.aio(
        "<h2>Quality Check</h2><p>Evaluating report quality...</p>"
    )
    await flyte.report.flush.aio()

    anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
    llm = ChatAnthropic(model=MODEL, api_key=anthropic_api_key)

    response = await llm.ainvoke(
        f'Evaluate this research report for the question: {query}\n\n'
        f'Report:\n{synthesis}\n\n'
        f'Rate the report quality from 1-10 and identify any gaps or missing perspectives. '
        f'Return JSON: {{"score": <int>, "gaps": [<string>, ...]}}\n'
        f'If the report is comprehensive (score >= 8) or there are no significant gaps, '
        f'return an empty gaps list.'
    )

    try:
        evaluation = json.loads(response.content)
        score = evaluation.get("score", 8)
        gaps = evaluation.get("gaps", [])
    except json.JSONDecodeError:
        score = 8
        gaps = []

    result = QualityResult(score=score, gaps=gaps)
    log.info(f"Score: {result.score}/10, Gaps: {len(result.gaps)}")

    gap_html = "".join(f"<li>{g}</li>" for g in result.gaps) if result.gaps else "<li>None</li>"
    await flyte.report.replace.aio(
        f"<h2>Quality Check</h2>"
        f"<p><b>Score:</b> {result.score}/10</p>"
        f"<p><b>Gaps:</b></p><ul>{gap_html}</ul>"
    )
    await flyte.report.flush.aio()

    return result

# ------------------------------------------------------------------
# Orchestrator: runs the LangGraph pipeline, backed by Flyte tasks
# ------------------------------------------------------------------

# {{docs-fragment pipeline}}
@env.task(report=True)
async def research_pipeline(
    query: str,
    num_topics: int = 3,
    max_searches: int = 2,
    max_iterations: int = 2,
) -> PipelineResult:
    """
    Research pipeline workflow:
    1. LangGraph plans sub-topics via plan_topics Flyte task
    2. LangGraph fans out research via Send → each dispatches to research_topic Flyte task
    3. LangGraph synthesizes results via synthesize Flyte task
    4. LangGraph evaluates quality via quality_check Flyte task
    5. If gaps found, loops back to step 2
    """
    log.info(f"Starting research pipeline: {query}")

    anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
    tavily_api_key = os.getenv("TAVILY_API_KEY")

    # Build the pipeline graph, passing all Flyte tasks as compute backends
    pipeline = build_pipeline_graph(
        anthropic_api_key=anthropic_api_key,
        tavily_api_key=tavily_api_key,
        plan_task=plan_topics,
        research_task=research_topic,
        synthesize_task=synthesize,
        quality_check_task=quality_check,
        model=MODEL,
    )

    # Visualize the graphs in report tabs
    graph_tab = flyte.report.get_tab("Agent Graphs")

    png_bytes = pipeline.get_graph().draw_mermaid_png()
    img_b64 = base64.b64encode(png_bytes).decode()
    graph_tab.log(f"""\
<h2>Research Pipeline</h2>\
<img src="data:image/png;base64,{img_b64}" alt="Research pipeline" />""")

    subgraph = build_research_subgraph(anthropic_api_key, tavily_api_key, max_searches, model=MODEL)
    sub_png = subgraph.get_graph().draw_mermaid_png()
    sub_b64 = base64.b64encode(sub_png).decode()
    graph_tab.log(f"""\
<h2>Research Agent (ReAct)</h2>\
<img src="data:image/png;base64,{sub_b64}" alt="ReAct research agent" />""")
    await flyte.report.flush.aio()

    # Run the pipeline — LangGraph controls the flow, Flyte tasks run the compute
    result = await pipeline.ainvoke({
        "query": query,
        "num_topics": num_topics,
        "max_searches": max_searches,
        "max_iterations": max_iterations,
        "iteration": 0,
        "topics": [],
        "research_results": [],
        "synthesis": "",
        "score": 0,
        "gaps": [],
        "final_report": "",
    })

    # Build the final report
    final_report = result["final_report"]
    sub_reports = [TopicReport(**r) for r in result["research_results"]]
    score = result.get("score", 0)
    iteration = result.get("iteration", 1) - 1

    await flyte.report.replace.aio(f"""\
<h2>Research Report</h2>\
<p><b>Query:</b> {query}</p>\
<p><b>Quality:</b> {score}/10 after {iteration} iteration(s)</p>\
<hr/>{md_to_html(final_report)}""")
    await flyte.report.flush.aio()

    log.info(f"Research pipeline complete. Score: {score}/10, Iterations: {iteration}")
    return PipelineResult(
        query=query,
        report=final_report,
        sub_reports=sub_reports,
        score=score,
        iterations=iteration,
    )

# {{/docs-fragment pipeline}}

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(research_pipeline(query="Compare quantum computing approaches"))
    print(run.url)
    run.wait()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/langgraph_agent_research/langgraph_agent_research.py*

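The script declares its dependencies in an inline metadata header (the PEP 723 script format). `uv run --script` reads this header, and `flyte.Image.from_uv_script` uses it to build the task image:

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.4.0",
#    "langgraph>=1.0.7",
#    "langchain-anthropic",
#    "tavily-python",
#    ...
# ]
# ///
```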

## Orchestrate the pipeline

The `research_pipeline` task builds the LangGraph workflow, renders graph diagrams in a report tab, and runs the full plan → research → synthesize → quality-check loop.

```
# ------------------------------------------------------------------
# Orchestrator: runs the LangGraph pipeline, backed by Flyte tasks
# ------------------------------------------------------------------

# {{docs-fragment pipeline}}
@env.task(report=True)
async def research_pipeline(
    query: str,
    num_topics: int = 3,
    max_searches: int = 2,
    max_iterations: int = 2,
) -> PipelineResult:
    """
    Research pipeline workflow:
    1. LangGraph plans sub-topics via plan_topics Flyte task
    2. LangGraph fans out research via Send → each dispatches to research_topic Flyte task
    3. LangGraph synthesizes results via synthesize Flyte task
    4. LangGraph evaluates quality via quality_check Flyte task
    5. If gaps found, loops back to step 2
    """
    log.info(f"Starting research pipeline: {query}")

    anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
    tavily_api_key = os.getenv("TAVILY_API_KEY")

    # Build the pipeline graph, passing all Flyte tasks as compute backends
    pipeline = build_pipeline_graph(
        anthropic_api_key=anthropic_api_key,
        tavily_api_key=tavily_api_key,
        plan_task=plan_topics,
        research_task=research_topic,
        synthesize_task=synthesize,
        quality_check_task=quality_check,
        model=MODEL,
    )

    # Visualize the graphs in report tabs
    graph_tab = flyte.report.get_tab("Agent Graphs")

    png_bytes = pipeline.get_graph().draw_mermaid_png()
    img_b64 = base64.b64encode(png_bytes).decode()
    graph_tab.log(f"""\
<h2>Research Pipeline</h2>\
<img src="data:image/png;base64,{img_b64}" alt="Research pipeline" />""")

    subgraph = build_research_subgraph(anthropic_api_key, tavily_api_key, max_searches, model=MODEL)
    sub_png = subgraph.get_graph().draw_mermaid_png()
    sub_b64 = base64.b64encode(sub_png).decode()
    graph_tab.log(f"""\
<h2>Research Agent (ReAct)</h2>\
<img src="data:image/png;base64,{sub_b64}" alt="ReAct research agent" />""")
    await flyte.report.flush.aio()

    # Run the pipeline — LangGraph controls the flow, Flyte tasks run the compute
    result = await pipeline.ainvoke({
        "query": query,
        "num_topics": num_topics,
        "max_searches": max_searches,
        "max_iterations": max_iterations,
        "iteration": 0,
        "topics": [],
        "research_results": [],
        "synthesis": "",
        "score": 0,
        "gaps": [],
        "final_report": "",
    })

    # Build the final report
    final_report = result["final_report"]
    sub_reports = [TopicReport(**r) for r in result["research_results"]]
    score = result.get("score", 0)
    iteration = result.get("iteration", 1) - 1

    await flyte.report.replace.aio(f"""\
<h2>Research Report</h2>\
<p><b>Query:</b> {query}</p>\
<p><b>Quality:</b> {score}/10 after {iteration} iteration(s)</p>\
<hr/>{md_to_html(final_report)}""")
    await flyte.report.flush.aio()

    log.info(f"Research pipeline complete. Score: {score}/10, Iterations: {iteration}")
    return PipelineResult(
        query=query,
        report=final_report,
        sub_reports=sub_reports,
        score=score,
        iterations=iteration,
    )

# {{/docs-fragment pipeline}}

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(research_pipeline(query="Compare quantum computing approaches"))
    print(run.url)
    run.wait()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/langgraph_agent_research/langgraph_agent_research.py*
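If you set LangGraph and Flyte aside, the control flow that `research_pipeline` drives can be sketched in plain `asyncio`. This is a minimal sketch, not the tutorial's implementation. The stubs `plan`, `research`, `synthesize`, and `check` are hypothetical stand-ins for the Flyte tasks, `asyncio.gather` stands in for LangGraph's `Send` fan-out, and the sketch assumes that the gaps reported by the quality check become the topics for the next research round.

```python
import asyncio
from dataclasses import dataclass, field


@dataclass
class Quality:
    score: int
    gaps: list[str] = field(default_factory=list)


# Hypothetical stand-ins for the plan_topics, research_topic,
# synthesize, and quality_check Flyte tasks.
async def plan(query: str, num_topics: int) -> list[str]:
    return [f"{query} / aspect {i}" for i in range(1, num_topics + 1)]


async def research(topic: str) -> str:
    await asyncio.sleep(0)  # pretend to do I/O
    return f"findings on {topic}"


async def synthesize(query: str, reports: list[str]) -> str:
    return f"{query}: " + "; ".join(reports)


async def check(synthesis: str, iteration: int) -> Quality:
    # Toy rule: report one gap on the first pass and none afterwards.
    if iteration == 1:
        return Quality(score=6, gaps=["missing cost comparison"])
    return Quality(score=8)


async def pipeline(query: str, num_topics: int = 3, max_iterations: int = 2) -> dict:
    topics = await plan(query, num_topics)
    reports: list[str] = []
    iteration = 0
    while True:
        iteration += 1
        # Fan out: research every pending topic concurrently.
        reports += await asyncio.gather(*(research(t) for t in topics))
        synthesis = await synthesize(query, reports)
        quality = await check(synthesis, iteration)
        # Stop when there are no gaps or the iteration budget is spent.
        if not quality.gaps or iteration >= max_iterations:
            break
        topics = quality.gaps  # assumption: gaps become the next round's topics
    return {
        "report": synthesis,
        "score": quality.score,
        "iterations": iteration,
        "n_reports": len(reports),
    }


if __name__ == "__main__":
    print(asyncio.run(pipeline("Compare quantum computing approaches")))
```

The `max_iterations` check is what guarantees the loop terminates even if the quality check keeps reporting gaps.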

Inside each research task, a ReAct subgraph (`graph.py`) uses `@flyte.trace` on Tavily search calls for observability.
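The `max_searches` argument caps how many searches the research agent may make for each topic. The toy loop below sketches that budget under stated assumptions: `fake_model` and `fake_search` are hypothetical stand-ins for the Claude call and the Tavily search, and `record_calls` is a simple decorator that logs each call's arguments, result, and duration as a rough stand-in for the per-call observability attributed to `@flyte.trace`. It is neither `@flyte.trace` nor the code in `graph.py`.

```python
import asyncio
import functools
import time

# In-memory call log; a stand-in for what a tracing backend would store.
CALL_LOG: list[dict] = []


def record_calls(fn):
    """Toy tracing decorator: log each async call's name, arguments, result, and duration."""
    @functools.wraps(fn)
    async def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = await fn(*args, **kwargs)
        CALL_LOG.append({
            "name": fn.__name__,
            "args": args,
            "result": result,
            "seconds": time.perf_counter() - start,
        })
        return result
    return wrapper


@record_calls
async def fake_search(query: str) -> str:
    # Hypothetical stand-in for a Tavily search call.
    return f"results for {query!r}"


def fake_model(topic: str, observations: list[str]) -> tuple[str, str]:
    # Hypothetical policy: a real LLM decides this. Here the "model" wants
    # three searches before it is ready to answer.
    if len(observations) >= 3:
        return ("answer", "")
    return ("search", f"{topic} (refinement {len(observations) + 1})")


async def react_agent(topic: str, max_searches: int = 2) -> str:
    observations: list[str] = []
    while True:
        action, arg = fake_model(topic, observations)
        # Answer when the model is ready or the search budget is spent.
        if action == "answer" or len(observations) >= max_searches:
            return f"Report on {topic} using {len(observations)} search(es)"
        observations.append(await fake_search(arg))
```

With `max_searches=2` the budget stops the agent after two searches; with a larger budget the stub model stops on its own after three.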

## Run the agent

Create secrets for Anthropic and Tavily:

```
flyte create secret internal-anthropic-api-key <YOUR_ANTHROPIC_API_KEY>
flyte create secret tavily_api_key <YOUR_TAVILY_API_KEY>
```
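Each `flyte.Secret(..., as_env_var=...)` entry in the task environment exposes a secret to the task as an environment variable (`ANTHROPIC_API_KEY` and `TAVILY_API_KEY` here). The tasks read these with `os.getenv`, which returns `None` instead of raising when a variable is absent, so a missing secret only surfaces later inside the LLM or search client. A fail-fast helper such as the hypothetical `require_env` below is one way to get a clearer error; it is a suggestion, not part of the example code.

```python
import os


def require_env(*names: str) -> dict[str, str]:
    """Return the requested environment variables, failing fast if any are missing or empty."""
    values = {name: os.environ.get(name, "") for name in names}
    missing = [name for name, value in values.items() if not value]
    if missing:
        raise RuntimeError(
            f"Missing environment variable(s): {', '.join(missing)}. "
            "Check that the corresponding Flyte secrets exist."
        )
    return values
```

Inside a task, you might call `keys = require_env("ANTHROPIC_API_KEY", "TAVILY_API_KEY")` before constructing the clients.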

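Clone the [unionai-examples](https://github.com/unionai/unionai-examples) repository and run the script from the [example directory](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/langgraph_agent_research):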

```
cd v2/tutorials/langgraph_agent_research
uv run --script langgraph_agent_research.py
```

Alternatively, invoke the `research_pipeline` task with `flyte run` and pass your own query:

```
flyte run langgraph_agent_research.py research_pipeline --query "Compare quantum computing approaches"
```

Check the **Agent Graphs** report tab for the LangGraph diagrams (the overall pipeline and the ReAct research agent) and the main report for the synthesized research output.

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/agents/mle-bot ===

# MLE bot: an autonomous ML engineer

> **📝 Note**
>
> Code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/mle_bot).

You have a dataset and a business question. Today, going from a raw CSV to a trained, evaluated model with a written report takes an ML engineer hours of experimentation: profiling the data, picking algorithms, engineering features, tuning hyperparameters, analyzing results, and iterating. What if you could describe the problem in plain English and let an agent handle the rest?

This tutorial walks you through building exactly that. You'll construct an autonomous ML engineer that takes a problem description and a dataset, designs experiments, runs them on cloud infrastructure, analyzes results, iterates, and produces a report summarizing the best model it found.

## TL;DR

- You'll build an agent that takes a natural language problem description and a CSV file, then produces a trained model and a detailed report comparing the results.
- The LLM reasons over dataset statistics, never raw data. Trusted tools compute statistics in the cloud, and only those statistics reach the LLM.
- LLM-generated orchestration code runs inside Flyte's sandbox: no imports, no network access, no filesystem. It can only call pre-approved tool functions.
- Each tool function runs as a durable Flyte task in the cloud, with retries, observability, and full traceability.

## The problem with LLMs and ML pipelines

If you ask an LLM to "train a model on this dataset," you run into a few issues fast. The LLM might hallucinate sklearn APIs that don't exist. It has no access to real compute, so it can't actually train anything. It runs everything in a single context with no way to handle large datasets. And if something goes wrong, there's no structured way to iterate.

The core tension is that LLMs are genuinely good at reasoning about *what* to try. Given a dataset profile showing class imbalance and temporal structure, a capable model will suggest rolling window features and appropriate resampling strategies. But LLMs are unreliable at *executing* those plans. They generate buggy code, lose track of variable names, and have no way to dispatch real compute.

The solution is to separate the two concerns. Let the LLM handle the planning: which algorithms to try, what feature engineering to apply, which hyperparameters to tune. Then hand the execution to trusted tool functions that run on real infrastructure. The LLM controls *what* happens. The tools control *how*.
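You can illustrate the split between *what* and *how* with a tool registry: the model's output is treated as data (a plan naming tools and arguments), and only functions registered ahead of time can run. This is a minimal, hypothetical sketch. The tool names and the JSON plan format are illustrative assumptions, and the MLE bot itself has the LLM write orchestration code that runs inside Flyte's sandbox, as described below, rather than emitting a JSON plan.

```python
import json
from typing import Any, Callable

# Registry of approved tools, keyed by function name.
TOOLS: dict[str, Callable[..., Any]] = {}


def tool(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Register a function as an approved tool."""
    TOOLS[fn.__name__] = fn
    return fn


@tool
def profile(rows: int) -> dict:
    return {"shape": [rows, 10]}


@tool
def train(algorithm: str, max_depth: int = 6) -> dict:
    return {"algorithm": algorithm, "max_depth": max_depth}


def run_plan(plan_json: str) -> list[Any]:
    """Execute an LLM-proposed plan; reject it outright if any step names an unapproved tool."""
    steps = json.loads(plan_json)
    unknown = [s["tool"] for s in steps if s["tool"] not in TOOLS]
    if unknown:
        raise PermissionError(f"Unapproved tool(s): {unknown}")
    return [TOOLS[s["tool"]](**s.get("args", {})) for s in steps]
```

Validating every step before running any of them means a plan with one bad step has no partial side effects.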

Think of it like giving a junior engineer access to a curated set of approved tools and reviewing their work. They can compose those tools in creative ways, but they can't go off-script and install random packages or hit arbitrary endpoints.

## How it works

The agent runs in five phases:

1. **Profile** the dataset using a trusted tool. The tool returns statistics (shape, class balance, feature correlations, missing values). The LLM never touches the raw data.
2. **Design** a batch of experiments. The LLM reads the profile and proposes 2 to 3 experiments, each with an algorithm, hyperparameters, and a feature engineering strategy.
3. **Execute** each experiment in parallel. For each one, the LLM generates Python orchestration code that chains together pre-approved tool functions. That code runs inside a restricted sandbox, and each tool call dispatches as a durable Flyte task on cloud compute.
4. **Analyze** the results. The LLM reviews metrics across experiments, optionally requests targeted data explorations (e.g., "are failures concentrated on specific machines?"), and decides whether to iterate with new experiments.
5. **Produce a report** summarizing the winning model, the experiment journey, and deployment recommendations.

Two things make this work. First, the LLM never sees the full dataset. The profiling tool runs in the cloud on managed compute and returns aggregated statistics plus a three-row sample, which keeps prompt sizes manageable and limits how much potentially sensitive data reaches LLM API calls. Second, the LLM-generated code runs inside Flyte's sandbox, where the only thing it can do is call your pre-approved tool functions. More on that shortly.
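Flyte's sandbox enforces its restrictions itself, so you don't implement them. Purely to illustrate the idea of an allowlist over generated code, here is a toy static check built on Python's `ast` module. It rejects imports, dunder attribute access, and calls to any name outside an approved set. It is a hypothetical sketch, not Flyte's sandbox implementation, and a static check like this is not a sufficient security boundary on its own.

```python
import ast

# Hypothetical allowlist: the only functions generated code may call.
ALLOWED_CALLS = {"profile_dataset", "split_dataset", "train_model"}


def check_generated_code(source: str) -> list[str]:
    """Return the violations found in LLM-generated code (an empty list means it passed)."""
    violations = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            violations.append(f"line {node.lineno}: imports are not allowed")
        elif isinstance(node, ast.Attribute) and node.attr.startswith("__"):
            violations.append(f"line {node.lineno}: dunder attribute {node.attr!r}")
        elif isinstance(node, ast.Call):
            # Only direct calls to approved tool names pass (no methods, no builtins).
            if not (isinstance(node.func, ast.Name) and node.func.id in ALLOWED_CALLS):
                violations.append(
                    f"line {node.lineno}: call to {ast.unparse(node.func)!r} is not allowed"
                )
    return violations
```

For example, `check_generated_code('import os\nos.system("ls")')` reports both the import and the `os.system` call.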

### What to expect

Here's what an actual run looks like on a synthetic predictive maintenance dataset (175k rows of sensor data from 20 industrial pumps, ~3% failure rate).

In the first iteration, the agent designed three experiments: a logistic regression baseline, an XGBoost model with rolling window features, and a random forest with lag features. After reviewing the results, it decided to continue. It requested two targeted explorations ("do failure cases show meaningfully higher vibration?" and "how do feature-target correlations vary by pump?"), then used those findings to design a second round of experiments with tuned feature engineering and class weighting.

After two iterations and five total experiments, the final rankings looked like this:

| Rank | Experiment | ROC-AUC | F1 | Recall | Precision |
|------|-----------|---------|------|--------|-----------|
| 1 | **Random Forest with Balanced Class Weights** | 0.7983 | 0.4284 | 0.4561 | 0.4038 |
| 2 | XGBoost with Feature Engineering | 0.7847 | 0.4568 | 0.4722 | 0.4425 |
| 3 | Enhanced XGBoost with Focused Feature Engineering | 0.7821 | 0.3565 | 0.4973 | 0.2778 |
| 4 | Random Forest with Lag Features | 0.7651 | 0.5206 | 0.4104 | 0.7116 |
| 5 | Baseline Logistic Regression | 0.7528 | 0.118 | 0.6496 | 0.0649 |

The agent autonomously explored different algorithms, feature strategies, and class imbalance techniques, then ranked everything by ROC-AUC. The full report includes the LLM's reasoning and generated code for every experiment, so you can trace exactly why it chose each approach and what code it wrote to implement it. Since the LLM makes different decisions each run, your results will vary, but the overall pattern (profile, design, execute, analyze, iterate) stays the same.
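Because F1 is the harmonic mean of precision and recall, F1 = 2PR / (P + R), you can sanity-check the table and see how the choice of ranking metric changes the winner. The snippet below uses only the numbers from the results table; it is an illustrative check, not part of the agent.

```python
# Rows copied from the results table: (experiment, roc_auc, f1, recall, precision)
RESULTS = [
    ("Random Forest with Balanced Class Weights", 0.7983, 0.4284, 0.4561, 0.4038),
    ("XGBoost with Feature Engineering", 0.7847, 0.4568, 0.4722, 0.4425),
    ("Enhanced XGBoost with Focused Feature Engineering", 0.7821, 0.3565, 0.4973, 0.2778),
    ("Random Forest with Lag Features", 0.7651, 0.5206, 0.4104, 0.7116),
    ("Baseline Logistic Regression", 0.7528, 0.118, 0.6496, 0.0649),
]


def f1_from(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall."""
    return 2 * precision * recall / (precision + recall)


for name, _, f1, recall, precision in RESULTS:
    # The reported F1 should match the one implied by P and R, up to rounding.
    assert abs(f1_from(precision, recall) - f1) < 1e-3, name

best_by_auc = max(RESULTS, key=lambda r: r[1])[0]
best_by_f1 = max(RESULTS, key=lambda r: r[2])[0]
print("Best by ROC-AUC:", best_by_auc)  # Random Forest with Balanced Class Weights
print("Best by F1:", best_by_f1)        # Random Forest with Lag Features
```

The random forest with lag features has the highest precision and F1 but ranks fourth on ROC-AUC, which is threshold-independent. Which metric to optimize depends on the relative cost of missed failures versus false alarms.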

## Declaring task environments

Before writing any tasks, you need to declare *where* and *how* they run. In Flyte v2, a `TaskEnvironment` bundles together the container image, resource requirements, secrets, and dependencies for a group of tasks.

The MLE bot uses two environments: one for the ML tools (pandas, sklearn, xgboost) and one for the agent itself (the OpenAI client and the sandbox runtime):

```
"""Flyte TaskEnvironment definitions for mle-bot.

Two environments:
- tool_env: Runs the ML tools (data loading, feature engineering, training, evaluation).
            Has sklearn, xgboost, pandas, numpy, joblib.
- agent_env: Runs the orchestrating agent (OpenAI calls, sandbox orchestration).
             Has openai, pydantic-monty. Depends on tool_env.
"""

# {{docs-fragment environments}}
import flyte

tool_env = flyte.TaskEnvironment(
    "mle-tools",
    resources=flyte.Resources(cpu=2, memory="4Gi"),
    image=(
        flyte.Image.from_debian_base(name="mle-tools-image").with_pip_packages(
            "pandas>=2.0.0",
            "scikit-learn>=1.3.0",
            "xgboost>=2.0.0",
            "numpy>=1.24.0",
            "joblib>=1.3.0",
        )
    ),
)

agent_env = flyte.TaskEnvironment(
    "mle-agent",
    resources=flyte.Resources(cpu=1, memory="2Gi"),
    secrets=[flyte.Secret(key="OPENAI_API_KEY", as_env_var="OPENAI_API_KEY")],
    env_vars={"PYTHONUNBUFFERED": "1"},
    image=(
        flyte.Image.from_debian_base(name="mle-agent-image")
        .with_apt_packages("git")
        .with_pip_packages(
            "openai>=1.0.0",
            "flyte[sandbox]",
        )
    ),
    depends_on=[tool_env],
)
# {{/docs-fragment environments}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/mle_bot/mle_bot/environments.py*

A few things to note. `flyte.Resources` sets the CPU and memory for every task in that environment. `flyte.Image.from_debian_base()` builds a container image with the packages you declare, so you don't have to write or maintain a Dockerfile. `flyte.Secret` injects a secret from your cluster's secret store as an environment variable. And `depends_on=[tool_env]` tells Flyte that the agent environment needs to be able to dispatch tasks in the tool environment. This is what lets the sandboxed code call tool functions that run on separate, appropriately resourced compute.

## Building durable tool functions

Each tool is a regular Python async function decorated with `@env.task`. That decorator turns it into a durable Flyte task: it runs in its own container with the resources declared on the environment, it's automatically retried on transient failures, and every invocation is tracked in the Flyte UI.

Data flows between tasks as `flyte.io.File` objects. A `File` is a reference to data in cloud storage. When a task needs the actual bytes, it calls `await data.download()` to pull them into the container's local filesystem. When it produces output, it creates a `File` from a local path and returns it. Flyte handles the upload to cloud storage when the task completes. The data itself never passes through the agent or the LLM.
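The pass-by-reference pattern is easy to see in miniature. The toy `FileRef` class below is a hypothetical stand-in for `flyte.io.File`, not its implementation: a local directory plays the role of the object store, `from_local` copies a file in and returns a reference holding only a path, and `download` copies it back out to a fresh temporary location. Only the small reference object moves between the producer and the consumer.

```python
import asyncio
import shutil
import tempfile
import uuid
from dataclasses import dataclass
from pathlib import Path

# Stand-in for cloud storage: a local temporary directory.
STORE = Path(tempfile.mkdtemp(prefix="toy-store-"))


@dataclass(frozen=True)
class FileRef:
    path: str  # where the bytes live in the "store"; the bytes themselves are not held here

    @classmethod
    async def from_local(cls, local_path: str) -> "FileRef":
        dest = STORE / f"{uuid.uuid4().hex}-{Path(local_path).name}"
        await asyncio.to_thread(shutil.copyfile, local_path, dest)  # "upload"
        return cls(str(dest))

    async def download(self) -> str:
        local_path = Path(tempfile.mkdtemp(prefix="toy-task-")) / Path(self.path).name
        await asyncio.to_thread(shutil.copyfile, self.path, local_path)  # "download"
        return str(local_path)


async def producer() -> FileRef:
    tmp = Path(tempfile.mkdtemp()) / "data.csv"
    tmp.write_text("x,y\n1,0\n2,1\n")
    return await FileRef.from_local(str(tmp))


async def consumer(data: FileRef) -> int:
    path = await data.download()
    return len(Path(path).read_text().splitlines()) - 1  # number of data rows


async def main() -> int:
    ref = await producer()  # only the reference is passed along
    return await consumer(ref)
```

In the real tools, `await data.download()` and `await File.from_local(...)` play these roles against cloud storage.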

Here's what the training tool looks like:

```
"""Model training tools.

A single unified interface for training classifiers with different algorithms.
The tool handles serialization, class imbalance, and basic hyperparameter passing.
"""

from flyte.io import File

from mle_bot.environments import tool_env
from mle_bot.schemas import (
    GradientBoostingParams,
    LogisticRegressionParams,
    RandomForestParams,
    XGBoostParams,
)

# {{docs-fragment train_model}}
@tool_env.task
async def train_model(
    data: File,
    target_column: str,
    algorithm: str,
    hyperparams: dict,
) -> File:
    """Train a classification model and return it serialized as a joblib file.

    Supports multiple algorithms through a single interface so the agent can
    dispatch different approaches without knowing implementation details.

    Args:
        data: CSV file with training data (features + target column).
        target_column: Name of the column to predict.
        algorithm: One of:
            "xgboost"            — Gradient boosted trees. Good default for tabular data.
                                   Handles missing values and class imbalance natively.
            "random_forest"      — Ensemble of decision trees. More robust to outliers.
            "logistic_regression"— Linear model. Fast baseline, good for linearly separable problems.
            "gradient_boosting"  — Sklearn GradientBoostingClassifier. Slower than xgboost
                                   but sometimes better on small datasets.
        hyperparams: Dict of algorithm-specific hyperparameters. Common keys:
            For xgboost / gradient_boosting:
                n_estimators (int, default 100): Number of trees.
                max_depth (int, default 6): Maximum tree depth.
                learning_rate (float, default 0.1): Step size shrinkage.
                scale_pos_weight (float): Ratio negative/positive — use for imbalanced data.
                                          Set to (n_negative / n_positive) to upweight minority class.
                subsample (float, default 1.0): Fraction of samples used per tree.
                colsample_bytree (float, default 1.0): Fraction of features per tree.
            For random_forest:
                n_estimators (int, default 100): Number of trees.
                max_depth (int or null, default null): Maximum tree depth (null = unlimited).
                min_samples_leaf (int, default 1): Minimum samples at a leaf node.
                class_weight (str, default "balanced"): "balanced" reweights by class frequency.
            For logistic_regression:
                C (float, default 1.0): Inverse regularization strength (higher = less regularization).
                max_iter (int, default 1000): Maximum iterations for solver.
                class_weight (str, default "balanced"): "balanced" reweights by class frequency.

    Returns:
        File — serialized model (joblib format, contains model + feature columns + target column).
    """
    # {{/docs-fragment train_model}}
    import tempfile

    import joblib
    import numpy as np
    import pandas as pd
    from flyte.io import File as FlyteFile
    from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, roc_auc_score

    path = await data.download()
    df = pd.read_csv(path)

    # Only use numeric columns — drop strings like machine_id automatically
    feature_cols = [c for c in df.select_dtypes(include=[np.number]).columns if c != target_column]
    X = df[feature_cols].values
    y = df[target_column].values

    class_dist = {str(k): int(v) for k, v in zip(*np.unique(y, return_counts=True))}
    n_positive = int((y == 1).sum())
    n_negative = int((y == 0).sum())
    default_scale = max(1.0, n_negative / n_positive) if n_positive > 0 else 1.0

    if algorithm == "xgboost":
        from xgboost import XGBClassifier
        p = XGBoostParams.model_validate({**hyperparams, "scale_pos_weight": hyperparams.get("scale_pos_weight", default_scale)})
        params = {**p.model_dump(), "eval_metric": "logloss", "random_state": 42}
        model = XGBClassifier(**params)

    elif algorithm == "random_forest":
        p = RandomForestParams.model_validate(hyperparams)
        params = {**p.model_dump(), "random_state": 42, "n_jobs": -1}
        model = RandomForestClassifier(**params)

    elif algorithm == "gradient_boosting":
        p = GradientBoostingParams.model_validate(hyperparams)
        params = {**p.model_dump(), "random_state": 42}
        model = GradientBoostingClassifier(**params)

    elif algorithm == "logistic_regression":
        p = LogisticRegressionParams.model_validate(hyperparams)
        params = {**p.model_dump(), "random_state": 42}
        model = LogisticRegression(**params)

    else:
        raise ValueError(f"Unknown algorithm: {algorithm!r}. Choose from: xgboost, random_forest, gradient_boosting, logistic_regression")

    model.fit(X, y)
    y_pred = model.predict(X)
    y_prob = model.predict_proba(X)[:, 1] if hasattr(model, "predict_proba") else y_pred

    train_metrics = {
        "accuracy": round(float(accuracy_score(y, y_pred)), 4),
        "f1": round(float(f1_score(y, y_pred, average="binary", zero_division=0)), 4),
        "precision": round(float(precision_score(y, y_pred, average="binary", zero_division=0)), 4),
        "recall": round(float(recall_score(y, y_pred, average="binary", zero_division=0)), 4),
        "roc_auc": round(float(roc_auc_score(y, y_prob)), 4),
    }

    # Feature importance (top 20)
    if hasattr(model, "feature_importances_"):
        importances = model.feature_importances_
        importance_dict = {feature_cols[i]: round(float(importances[i]), 4) for i in range(len(feature_cols))}
        importance_dict = dict(sorted(importance_dict.items(), key=lambda x: x[1], reverse=True)[:20])
    elif hasattr(model, "coef_"):
        importances = abs(model.coef_[0])
        importance_dict = {feature_cols[i]: round(float(importances[i]), 4) for i in range(len(feature_cols))}
        importance_dict = dict(sorted(importance_dict.items(), key=lambda x: x[1], reverse=True)[:20])
    else:
        importance_dict = {}

    model_file = tempfile.NamedTemporaryFile(suffix=".joblib", delete=False)
    joblib.dump({"model": model, "feature_columns": feature_cols, "target_column": target_column}, model_file.name)
    model_file.close()

    return await FlyteFile.from_local(model_file.name)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/mle_bot/mle_bot/tools/training.py*
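When no `scale_pos_weight` is supplied, `train_model` defaults it to the negative-to-positive ratio for XGBoost, while the random forest and logistic regression options document `class_weight="balanced"` as their default. The sketch below applies both formulas to the predictive-maintenance example, assuming exactly 3% positives out of 175,000 rows (the text only says about 3%, so the counts are illustrative). The `balanced` formula, `n_samples / (n_classes * count_for_class)`, is scikit-learn's documented definition.

```python
def scale_pos_weight(n_negative: int, n_positive: int) -> float:
    # Mirrors train_model's default_scale: at least 1.0, and 1.0 when there are no positives.
    return max(1.0, n_negative / n_positive) if n_positive > 0 else 1.0


def balanced_weights(counts: dict[int, int]) -> dict[int, float]:
    # scikit-learn's class_weight="balanced": n_samples / (n_classes * count_for_class)
    n_samples = sum(counts.values())
    return {label: n_samples / (len(counts) * n) for label, n in counts.items()}


n_rows = 175_000
n_pos = n_rows * 3 // 100  # assumption: exactly 3% failures -> 5,250
n_neg = n_rows - n_pos     # 169,750

spw = scale_pos_weight(n_neg, n_pos)
weights = balanced_weights({0: n_neg, 1: n_pos})
print(f"scale_pos_weight = {spw:.2f}")                                     # 32.33
print(f"balanced weights = {weights[0]:.3f}, {weights[1]:.3f}")             # 0.515, 16.667
print(f"positive/negative weight ratio = {weights[1] / weights[0]:.2f}")  # 32.33
```

Both choices end up weighting a positive example about 32 times as heavily as a negative one. The `balanced` scheme also scales the negative class down, so the absolute weights differ but the ratio matches.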

And here's the profiling tool, which is the first thing the agent calls. It computes dataset statistics that the LLM will use to design experiments:

```
"""Data loading, profiling, and splitting tools.

These tools are safe, general-purpose, and side-effect free.
They run as durable Flyte tasks so they execute in the cloud on managed compute.
"""

from flyte.io import File

from mle_bot.environments import tool_env

# {{docs-fragment profile_dataset}}
@tool_env.task
async def profile_dataset(data: File, target_column: str) -> dict:
    """Profile a dataset and return statistics that inform ML problem design.

    Call this first before designing any experiments. The returned profile tells
    you the shape, column types, class balance, missing values, and numeric
    statistics — everything needed to choose algorithms and feature strategies.

    Args:
        data: CSV file to profile.
        target_column: Name of the column to predict.

    Returns a dict with keys:
        - shape: [n_rows, n_cols]
        - columns: list of all column names
        - dtypes: {col: dtype_string, ...}
        - numeric_columns: list of numeric column names (excluding target)
        - categorical_columns: list of non-numeric column names (excluding target)
        - target_distribution: {class_value: count, ...}
        - class_balance: {class_value: pct, ...}  (proportions, sum to 100)
        - missing_pct: {col: pct_missing, ...}
        - numeric_stats: {col: {mean, std, min, max, median}, ...}
        - feature_target_corr: {col: corr, ...}  (sorted by absolute correlation, descending)
        - n_classes: int — number of unique target values
        - is_imbalanced: bool — True if minority class < 20% of data
        - sample: list of 3 example rows as dicts
    """
    import numpy as np
    import pandas as pd

    path = await data.download()
    df = pd.read_csv(path)

    target_counts = df[target_column].value_counts()
    class_balance = (df[target_column].value_counts(normalize=True) * 100).round(2).to_dict()
    minority_pct = float(min(class_balance.values()))

    numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
    categorical_cols = df.select_dtypes(exclude=[np.number]).columns.tolist()

    numeric_stats = {}
    for col in numeric_cols:
        if col == target_column:
            continue
        numeric_stats[col] = {
            "mean": round(float(df[col].mean()), 4),
            "std": round(float(df[col].std()), 4),
            "min": round(float(df[col].min()), 4),
            "max": round(float(df[col].max()), 4),
            "median": round(float(df[col].median()), 4),
        }

    # Point-biserial correlation between each numeric feature and the target
    feature_target_corr = {}
    for col in numeric_cols:
        if col == target_column:
            continue
        corr = float(df[col].corr(df[target_column]))
        if not np.isnan(corr):
            feature_target_corr[col] = round(corr, 4)
    # Sort by absolute correlation descending
    feature_target_corr = dict(
        sorted(feature_target_corr.items(), key=lambda x: abs(x[1]), reverse=True)
    )

    return {
        "shape": list(df.shape),
        "columns": list(df.columns),
        "dtypes": {col: str(dtype) for col, dtype in df.dtypes.items()},
        "numeric_columns": [c for c in numeric_cols if c != target_column],
        "categorical_columns": [c for c in categorical_cols if c != target_column],
        "target_distribution": {str(k): int(v) for k, v in target_counts.items()},
        "class_balance": {str(k): float(v) for k, v in class_balance.items()},
        "missing_pct": {col: round(float(pct * 100), 2) for col, pct in df.isnull().mean().items()},
        "numeric_stats": numeric_stats,
        "feature_target_corr": feature_target_corr,
        "n_classes": int(df[target_column].nunique()),
        "is_imbalanced": minority_pct < 20.0,
        "sample": df.head(3).fillna("").to_dict(orient="records"),
    }
# {{/docs-fragment profile_dataset}}

@tool_env.task
async def split_dataset(
    data: File,
    target_column: str,
    test_size: float,
    time_column: str,
    split_type: str,
) -> File:
    """Split a dataset and return either the train or test half.

    Call this twice — once with split_type="train" and once with split_type="test" —
    to get both halves. Always split before feature engineering to prevent data leakage.

    Args:
        data: CSV file to split.
        target_column: Name of the column to predict.
        test_size: Fraction of data to use for testing (e.g. 0.2 for 20%).
        time_column: If non-empty, sort by this column and take the last
                     `test_size` fraction as test (time-based split, no shuffling).
                     If empty string "", use stratified random split.
        split_type: Which half to return — "train" or "test".

    Returns:
        File — CSV file containing the requested split (train or test rows).
    """
    import tempfile

    import pandas as pd
    from flyte.io import File as FlyteFile
    from sklearn.model_selection import train_test_split

    path = await data.download()
    df = pd.read_csv(path)

    if time_column:
        df = df.sort_values(time_column).reset_index(drop=True)
        split_idx = int(len(df) * (1 - test_size))
        train_df = df.iloc[:split_idx]
        test_df = df.iloc[split_idx:]
    else:
        train_df, test_df = train_test_split(
            df,
            test_size=test_size,
            stratify=df[target_column],
            random_state=42,
        )

    selected_df = train_df if split_type == "train" else test_df

    out = tempfile.NamedTemporaryFile(suffix=".csv", delete=False)
    selected_df.to_csv(out.name, index=False)
    out.close()
    return await FlyteFile.from_local(out.name)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/mle_bot/mle_bot/tools/data.py*

The full tool inventory includes ten functions: `profile_dataset`, `split_dataset`, `explore_dataset`, `engineer_features`, `select_features`, `resample_dataset`, `train_model`, `get_predictions`, `evaluate_model`, and `rank_experiments`. Each one does exactly one thing. The LLM composes them into pipelines, but each tool enforces its own correctness guarantees internally. For example, `resample_dataset` only applies resampling to training data, never test data, regardless of what the LLM asks for.
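The ordering those guarantees protect (split first, then resample only the train half) is easy to see on plain Python rows. This sketch stands in for the time-based branch of `split_dataset` and for a simple random-oversampling strategy. `time_split` and `random_oversample` are hypothetical helpers, and reading `target_ratio` as a minority-to-majority count ratio is an assumption about what `resample_dataset` means by it.

```python
import math
import random
from collections import Counter


def time_split(rows: list[dict], time_column: str, test_size: float):
    """Sort by time and hold out the last `test_size` fraction as test (no shuffling)."""
    ordered = sorted(rows, key=lambda r: r[time_column])
    split_idx = int(len(ordered) * (1 - test_size))
    return ordered[:split_idx], ordered[split_idx:]


def random_oversample(train_rows: list[dict], target_column: str,
                      target_ratio: float, seed: int = 42) -> list[dict]:
    """Duplicate random minority-class rows until minority/majority >= target_ratio.

    Assumption: target_ratio is the desired minority-to-majority count ratio.
    """
    counts = Counter(r[target_column] for r in train_rows)
    if len(counts) < 2:
        return list(train_rows)  # a single class: nothing to rebalance
    ranked = counts.most_common()
    n_major = ranked[0][1]
    minority, n_minor = ranked[-1]
    needed = max(0, math.ceil(target_ratio * n_major) - n_minor)
    rng = random.Random(seed)
    minority_rows = [r for r in train_rows if r[target_column] == minority]
    return train_rows + [dict(rng.choice(minority_rows)) for _ in range(needed)]


# 20 hourly readings with 3 failures; rows arrive out of time order
rows = [{"t": i, "failure": 1 if i in (2, 9, 17) else 0} for i in reversed(range(20))]
train, test = time_split(rows, "t", test_size=0.25)
train = random_oversample(train, "failure", target_ratio=0.5)  # train split only
```

Because the test rows are carved off before `random_oversample` runs, the held-out hours keep their real failure rate, so evaluation metrics are not inflated by duplicated positives.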

## Guiding the LLM with domain knowledge

The quality of the agent's experiments depends heavily on what you tell it. The MLE Bot bakes ML best practices directly into its system prompts, so the LLM starts from a solid foundation rather than relying on whatever it picked up during pretraining.

The orchestration prompt, for example, includes guidance on feature engineering strategies, class imbalance handling, and algorithm selection. It's dynamically built from the dataset profile, so the LLM sees concrete context alongside the general advice:

```python
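def _build_orchestration_system_prompt(profile: dict) -> str:
    # Values interpolated into the prompt, pulled from the dataset profile
    shape = profile.get("shape", [0, 0])
    numeric_cols = profile.get("numeric_columns", [])
    class_balance = profile.get("class_balance", {})
    is_imbalanced = profile.get("is_imbalanced", False)
    corr = profile.get("feature_target_corr", {})
    # Show only the 8 strongest feature-target correlations
    corr_str = ", ".join(f"{k}: {v:+.3f}" for k, v in list(corr.items())[:8]) if corr else "n/a"
    return f"""\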
You are an expert ML engineer. Your job is to design and write the best
possible pipeline for a machine learning experiment.

## Dataset context
Shape: {shape[0]:,} rows × {shape[1]} columns
Numeric features: {numeric_cols}
Class balance: {class_balance}, imbalanced: {is_imbalanced}
Feature-target correlations (raw): {corr_str}

## General ML best practices
**Feature engineering**:
- Sequential/time-series data: rolling window features capture trends
  that point-in-time readings miss. Choose window sizes relative to
  the prediction horizon and temporal resolution of the data.
- Consider skipping feature engineering entirely for a baseline.

**Class imbalance** (when is_imbalanced=true):
- Tree ensembles: use class_weight="balanced" or scale_pos_weight.
- The default 0.5 decision threshold may not be optimal.

**Algorithm selection**:
- XGBoost: strong default for tabular data. Start here.
- RandomForest: more robust to outliers, good for noisy data.
- LogisticRegression: fast linear baseline.
...
"""
```

This means the LLM doesn't start from a blank canvas. It gets a structured briefing that combines the actual dataset characteristics with best practices for handling them. The best-practice text is the same for every dataset, but it is keyed to fields in the profile: when the profile reports `is_imbalanced: true`, the imbalance guidance tells the LLM which hyperparameters to adjust (`class_weight`, `scale_pos_weight`), and the full prompt lists resampling strategies to consider. When there's a timestamp column, the prompt suggests rolling window features with guidance on window sizing.
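To see what the LLM actually receives, here is a cut-down renderer for just the "Dataset context" lines. It uses the same formatting expressions as the prompt builder (`{shape[0]:,}` for the row count, `{v:+.3f}` for correlations). `render_dataset_context` is a hypothetical helper, and the profile values are invented for illustration.

```python
def render_dataset_context(profile: dict) -> str:
    """Render the 'Dataset context' lines the way the prompt builder interpolates them."""
    shape = profile.get("shape", [0, 0])
    corr = profile.get("feature_target_corr", {})
    # Keep the 8 strongest correlations (profile_dataset already sorts by |corr|)
    corr_str = ", ".join(f"{k}: {v:+.3f}" for k, v in list(corr.items())[:8]) if corr else "n/a"
    return (
        f"Shape: {shape[0]:,} rows × {shape[1]} columns\n"
        f"Numeric features: {profile.get('numeric_columns', [])}\n"
        f"Class balance: {profile.get('class_balance', {})}, "
        f"imbalanced: {profile.get('is_imbalanced', False)}\n"
        f"Feature-target correlations (raw): {corr_str}"
    )


# Toy profile: invented values, shaped like profile_dataset's output
profile = {
    "shape": [12000, 6],
    "numeric_columns": ["vibration_mms", "temperature_c"],
    "class_balance": {"0": 97.0, "1": 3.0},
    "is_imbalanced": True,
    "feature_target_corr": {"vibration_mms": 0.41, "temperature_c": -0.12},
}
print(render_dataset_context(profile))
```

The signs matter: a `+`/`-` prefix on every correlation lets the LLM tell at a glance which features rise or fall with the target.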

The user's problem description also has a significant impact on the agent's behavior. A query like "Predict pump failures 24 hours before they happen based on sensor readings" tells the LLM that this is a time-series classification problem with a specific prediction horizon. That shapes everything: the LLM will favor temporal feature engineering (rolling windows sized relative to that 24-hour horizon), pick algorithms that handle imbalanced binary classification well, and focus on recall as a key metric because missing a failure is worse than a false alarm. Change the query to something like "Classify machine health status from the latest sensor snapshot" and the same dataset would produce a completely different set of experiments, with less emphasis on temporal features and more on cross-sectional patterns.

## The agent loop: profile, design, execute, iterate

The agent's main function orchestrates five phases. Let's walk through each one.

**Phase 1: Profile.** The agent calls `profile_dataset` directly as a trusted tool. This isn't sandboxed because there's nothing to protect against here: the function is your code, running on your compute. The `flyte.group` call organizes this step in the Flyte UI so you can inspect it later.

```python
with flyte.group("profile"):
    profile = await profile_dataset(data, target_column)
```

**Phase 2: Design.** The profile dict goes to the LLM along with the problem description. The LLM returns a structured response matching the `InitialDesign` schema:

```python
"""Pydantic schemas for tool inputs and agent data structures.

These models define the expected shape of configs and results throughout the agent.

Important: Tool functions that are called from the Monty sandbox must accept plain
`dict` at the boundary (Monty can't import or instantiate classes). Each tool parses
its incoming dict into the appropriate model internally for validation. In agent.py,
use `.model_dump()` to convert models back to dicts before passing to the sandbox.
"""

from typing import Literal

from pydantic import BaseModel, Field

# ---------------------------------------------------------------------------
# Feature engineering
# ---------------------------------------------------------------------------

class FeatureConfig(BaseModel):
    """Configuration for the engineer_features tool."""

    group_column: str = Field(
        default="",
        description="Column to group by for rolling/lag features (e.g. 'machine_id'). "
                    "Required when rolling_columns or lag_columns is specified.",
    )
    time_column: str = Field(
        default="",
        description="Timestamp column to sort by before computing rolling/lag features.",
    )
    rolling_columns: list[str] = Field(
        default_factory=list,
        description="Numeric columns to compute rolling statistics for (mean, std, min, max).",
    )
    windows: list[int] = Field(
        default_factory=list,
        description="Rolling window sizes in rows (e.g. [6, 12, 24]).",
    )
    lag_columns: list[str] = Field(
        default_factory=list,
        description="Numeric columns to create lag features for.",
    )
    lags: list[int] = Field(
        default_factory=list,
        description="Lag steps in rows (e.g. [1, 3, 6]).",
    )
    normalize: bool = Field(
        default=False,
        description="If true, z-score normalize all numeric columns except target_column.",
    )
    target_column: str = Field(
        default="",
        description="Column to exclude from normalization. Required when normalize=True.",
    )
    drop_columns: list[str] = Field(
        default_factory=list,
        description="Columns to remove from output (e.g. raw timestamp after rolling).",
    )
    fillna_method: Literal["forward", "zero", "drop"] = Field(
        default="forward",
        description="How to fill NaN values introduced by rolling/lag. "
                    "'forward' forward-fills then fills remaining with 0. "
                    "'zero' fills all NaN with 0. 'drop' drops rows with NaN.",
    )

# ---------------------------------------------------------------------------
# Training hyperparameters (per algorithm)
# ---------------------------------------------------------------------------

class XGBoostParams(BaseModel):
    n_estimators: int = Field(default=100, ge=1)
    max_depth: int = Field(default=6, ge=1, le=20)
    learning_rate: float = Field(default=0.1, gt=0, le=1)
    scale_pos_weight: float = Field(
        default=1.0, ge=0,
        description="Set to n_negative/n_positive for imbalanced datasets.",
    )
    subsample: float = Field(default=1.0, gt=0, le=1)
    colsample_bytree: float = Field(default=1.0, gt=0, le=1)

class RandomForestParams(BaseModel):
    n_estimators: int = Field(default=100, ge=1)
    max_depth: int | None = Field(
        default=None,
        description="Maximum tree depth. None means unlimited.",
    )
    min_samples_leaf: int = Field(default=1, ge=1)
    class_weight: Literal["balanced", "balanced_subsample"] | None = Field(default="balanced")

class GradientBoostingParams(BaseModel):
    n_estimators: int = Field(default=100, ge=1)
    max_depth: int = Field(default=3, ge=1, le=10)
    learning_rate: float = Field(default=0.1, gt=0, le=1)
    subsample: float = Field(default=1.0, gt=0, le=1)

class LogisticRegressionParams(BaseModel):
    C: float = Field(default=1.0, gt=0, description="Inverse regularization strength.")
    max_iter: int = Field(default=1000, ge=100)
    class_weight: Literal["balanced"] | None = Field(default="balanced")

# ---------------------------------------------------------------------------
# Experiment design (used by agent.py, validated when parsing LLM JSON)
# ---------------------------------------------------------------------------

Algorithm = Literal["xgboost", "random_forest", "gradient_boosting", "logistic_regression"]

# {{docs-fragment schemas}}
class ExperimentConfig(BaseModel):
    """One experiment to run — produced by the LLM and executed by the agent."""

    name: str = Field(description="Short descriptive name for this experiment.")
    algorithm: Algorithm
    hyperparams: dict = Field(
        default_factory=dict,
        description="Algorithm-specific hyperparameters. Will be validated inside train_model.",
    )
    feature_config: FeatureConfig = Field(default_factory=FeatureConfig)
    rationale: str = Field(default="", description="Why this experiment is worth running.")

class InitialDesign(BaseModel):
    """LLM response for initial experiment design."""

    problem_type: str = Field(default="binary_classification")
    primary_metric: Literal["roc_auc", "f1", "recall"] = Field(default="roc_auc")
    reasoning: str
    experiments: list[ExperimentConfig]

class IterationDecision(BaseModel):
    """LLM response after analyzing experiment results."""

    should_continue: bool
    reasoning: str
    exploration_requests: list[dict] = Field(
        default_factory=list,
        description="Optional list of explore_dataset config dicts to run before designing "
                    "the next batch. Each dict is passed directly to explore_dataset.",
    )
    next_experiments: list[ExperimentConfig] = Field(default_factory=list)
# {{/docs-fragment schemas}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/mle_bot/mle_bot/schemas.py*
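The `schemas.py` docstring describes a dict-at-the-boundary pattern: sandboxed code can only build dict literals, so each tool accepts a plain `dict` and parses it into a model itself. The sketch below shows that pattern with a stdlib dataclass standing in for the pydantic `FeatureConfig`. `FeatureConfigSketch` and `parse_feature_config` are hypothetical. Two deliberate differences from the real model: the sketch rejects unknown keys (pydantic's default is to ignore them), and it enforces the `normalize`/`target_column` rule that `FeatureConfig` only states in a field description.

```python
from dataclasses import asdict, dataclass, field, fields
from typing import Literal, get_args

FillNaMethod = Literal["forward", "zero", "drop"]


@dataclass
class FeatureConfigSketch:
    """A few FeatureConfig fields, with the same defaults as the pydantic model."""
    rolling_columns: list[str] = field(default_factory=list)
    windows: list[int] = field(default_factory=list)
    normalize: bool = False
    target_column: str = ""
    fillna_method: str = "forward"

    def __post_init__(self) -> None:
        # Mimic the Literal check pydantic performs on fillna_method
        if self.fillna_method not in get_args(FillNaMethod):
            raise ValueError(f"fillna_method must be one of {get_args(FillNaMethod)}")
        # FeatureConfig only documents this rule; the sketch enforces it
        if self.normalize and not self.target_column:
            raise ValueError("target_column is required when normalize=True")


def parse_feature_config(raw: dict) -> FeatureConfigSketch:
    """Tool boundary: accept the plain dict sandboxed code built, then validate it."""
    unknown = set(raw) - {f.name for f in fields(FeatureConfigSketch)}
    if unknown:
        raise ValueError(f"unknown config keys: {sorted(unknown)}")
    return FeatureConfigSketch(**raw)


cfg = parse_feature_config({"rolling_columns": ["vibration_mms"], "windows": [6, 24]})
print(asdict(cfg))
```

Here `asdict` plays the role that `.model_dump()` plays in `agent.py`: turning a validated config back into a plain dict before it crosses into the sandbox.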

The LLM typically proposes 2 to 3 experiments: a baseline with minimal feature engineering, an experiment with rolling window features for temporal data, and perhaps one testing a different algorithm or resampling strategy.
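A rolling-window experiment boils down to the `group_column`, `time_column`, `rolling_columns`, and `windows` fields of `FeatureConfig`. Here is a minimal stdlib sketch of trailing rolling means. It assumes one mean per window per column (the real tool also computes std, min, and max) and uses a hypothetical `<column>_mean_<w>` naming scheme; `add_rolling_means` is not part of MLE Bot.

```python
from collections import defaultdict
from statistics import fmean


def add_rolling_means(rows, group_column, time_column, column, windows):
    """Trailing rolling means of `column` per group, computed in time order.

    Windows are counted in rows. A row gets None until its window is full,
    similar to pandas' default rolling(window) behavior; gaps like these are
    what FeatureConfig.fillna_method is for.
    """
    by_group = defaultdict(list)
    for row in rows:
        by_group[row[group_column]].append(row)

    out = []
    for group_rows in by_group.values():
        history = []  # per-group history, so windows never mix machines
        for row in sorted(group_rows, key=lambda r: r[time_column]):
            history.append(row[column])
            new_row = dict(row)
            for w in windows:
                new_row[f"{column}_mean_{w}"] = fmean(history[-w:]) if len(history) >= w else None
            out.append(new_row)
    return out


rows = [
    {"machine_id": "A", "hour": 3, "vibration_mms": 4.0},
    {"machine_id": "A", "hour": 1, "vibration_mms": 2.0},
    {"machine_id": "B", "hour": 1, "vibration_mms": 10.0},
    {"machine_id": "A", "hour": 2, "vibration_mms": 3.0},
    {"machine_id": "B", "hour": 2, "vibration_mms": 20.0},
]
features = add_rolling_means(rows, "machine_id", "hour", "vibration_mms", windows=[2])
```

Because windows are counted in rows, their time span depends on the data's resolution: with hourly readings, `windows=[6, 12, 24]` spans six hours up to one day.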

**Phase 3: Execute in parallel.** All experiments in a batch run simultaneously using `asyncio.gather()`. Each experiment dispatches its own set of durable Flyte tasks:

```python
"""MLE Agent — orchestrates ML experiments using Flyte's durable sandbox.

The agent:
  1. Profiles the dataset using a trusted tool (data never touches the LLM).
  2. Asks OpenAI to design a set of experiments (algorithms, hyperparams, feature strategy).
  3. For each experiment, generates Monty orchestration code and executes it via
     flyte.sandbox.orchestrate_local(), which dispatches the heavy compute as durable tasks.
  4. Analyzes results, iterates if needed.
  5. Produces a model card summarizing the winning model.

The Monty sandbox ensures the LLM-generated orchestration code is safe — it can only
call the pre-approved tool functions and has no access to imports, network, or filesystem.
"""

import asyncio
import inspect
import json
import os
import textwrap
from dataclasses import dataclass

import flyte
import flyte.sandbox
from flyte.io import File

from mle_bot.schemas import ExperimentConfig, InitialDesign, IterationDecision

from mle_bot.environments import agent_env
from mle_bot.tools.data import profile_dataset, split_dataset
from mle_bot.tools.evaluation import evaluate_model, rank_experiments
from mle_bot.tools.exploration import explore_dataset
from mle_bot.tools.features import engineer_features
from mle_bot.tools.predictions import get_predictions
from mle_bot.tools.resampling import resample_dataset
from mle_bot.tools.selection import select_features
from mle_bot.tools.training import train_model

# {{docs-fragment tools}}
# All tools exposed to the sandbox.
# Keys must match the function names used in LLM-generated orchestration code.
TOOLS = [
    profile_dataset, split_dataset, explore_dataset,
    engineer_features, resample_dataset, select_features,
    train_model, get_predictions, evaluate_model, rank_experiments,
]
TOOLS_BY_NAME = {t.func.__name__ if hasattr(t, "func") else t.__name__: t for t in TOOLS}
# {{/docs-fragment tools}}

# ---------------------------------------------------------------------------
# Prompt builders
# ---------------------------------------------------------------------------

def _tool_signatures() -> str:
    """Build a summary of available tool signatures and docstrings for the system prompt."""
    parts = []
    for t in TOOLS:
        func = t.func if hasattr(t, "func") else t
        sig = inspect.signature(func)
        doc = inspect.getdoc(func) or ""
        # Trim docstring to first 40 lines so prompt stays manageable
        doc_lines = doc.splitlines()[:40]
        short_doc = "\n    ".join(doc_lines)
        parts.append(f"async def {func.__name__}{sig}:\n    \"\"\"{short_doc}\"\"\"\n    ...")
    return "\n\n".join(parts)

# {{docs-fragment orchestration_prompt}}
def _build_orchestration_system_prompt(profile: dict) -> str:
    monty_rules = flyte.sandbox.ORCHESTRATOR_SYNTAX_PROMPT
    tools_section = _tool_signatures()
    is_imbalanced = profile.get("is_imbalanced", False)
    class_balance = profile.get("class_balance", {})
    columns = profile.get("columns", [])
    numeric_cols = profile.get("numeric_columns", [])
    categorical_cols = profile.get("categorical_columns", [])
    corr = profile.get("feature_target_corr", {})
    corr_str = ", ".join(f"{k}: {v:+.3f}" for k, v in list(corr.items())[:8]) if corr else "n/a"
    shape = profile.get("shape", [0, 0])
    return f"""\
You are an expert ML engineer. Your job is to design and write the best possible
pipeline for a machine learning experiment, then generate the Python orchestration
code to execute it.

The code runs inside a restricted sandbox. The last expression in your code
is returned as the result. All tool calls are made like regular function calls —
you do NOT need to await them.

## Dataset context

Shape: {shape[0]:,} rows × {shape[1]} columns
Numeric features: {numeric_cols}
Categorical features (excluded from model — not supported): {categorical_cols}
Class balance: {class_balance}, imbalanced: {is_imbalanced}
Feature-target correlations (raw, point-biserial): {corr_str}

## General ML best practices — apply these based on the dataset context above

**Feature engineering** (engineer_features tool):
- Sequential/time-series data (timestamp column present, rows ordered over time):
  rolling window features (means, stds, min/max) capture trends that point-in-time
  readings miss. Lag features capture recent history. Choose window sizes relative
  to the prediction horizon and temporal resolution of the data.
- Tabular cross-sectional data: normalization helps linear models and distance-based
  methods. Interaction terms can help if correlations are weak individually.
- Consider skipping feature engineering entirely for a baseline — it establishes
  whether raw features already carry enough signal.

**Class imbalance** (when is_imbalanced=true):
- Tree ensembles: use class_weight="balanced" or scale_pos_weight=n_neg/n_pos.
- Threshold: the default 0.5 decision threshold may not be optimal — the model's
  probability output is what matters, threshold is tuned at deployment time.
- Metric: ROC-AUC is robust to imbalance; avg_precision is better when positives
  are very rare.

**Algorithm selection**:
- XGBoost / GradientBoosting: strong default for tabular data, handles missing
  values, built-in imbalance handling. Start here unless data is very small.
- RandomForest: more robust to outliers, good for noisy data, parallelizes well.
- LogisticRegression: fast linear baseline. Useful to establish whether the
  problem is linearly separable before adding complexity.
- Prefer simpler models when n_samples < 5,000 to avoid overfitting.

**Resampling** (resample_dataset tool) — data-level imbalance handling:
- Use when class_weight/scale_pos_weight alone isn't improving recall adequately,
  or when you want to test whether data-level vs algorithm-level imbalance handling
  works better for this dataset.
- ONLY resample the TRAIN split — never test. Resampling test data gives misleading metrics.
- "oversample": fast, no new information, good first try.
- "smote": synthetic samples via interpolation — more diverse than random oversampling,
  better for high-dimensional or sparse minority classes.
- "undersample": loses majority data — only useful when majority class is very large
  and training speed is a concern.

**Feature selection** (select_features tool) — prune after feature engineering:
- Use after engineer_features when the feature count is large (20+) and you suspect
  many features are redundant or noisy (e.g., rolling stats at many window sizes).
- "mutual_info": ranks by non-linear dependence with target — best general choice.
- "variance_threshold": drops near-constant features — cheap first pass.
- "correlation_filter": drops redundant features that are highly correlated with
  each other — useful when many rolling windows capture the same trend.
- Can be applied before or after splitting. Apply the same selection to both train
  and test to ensure the model sees the same features at evaluation time.

**Prediction output** (get_predictions tool) — enables two advanced patterns:
1. Error analysis: train a model → get_predictions(model, test_file, target) →
   explore_dataset(predictions_file, {{"class_distributions": ["feature_x"],
   "target_column": "correct"}}) to see which examples the model gets wrong.
   Use this to inform feature engineering for the next iteration.
2. Stacking: train base_model → get_predictions(base_model, train_file, target) →
   train a meta_model on the predictions CSV (use "predicted_prob" as a feature
   alongside original features) → evaluate meta_model on test.
   get_predictions returns a CSV with columns: all originals + predicted_prob,
   predicted_class, correct.

**Pipeline structure** — you are not required to follow a fixed sequence.
Design what makes sense for this specific experiment.

## Available tools

{tools_section}

## Monty sandbox restrictions

{monty_rules}

## Critical patterns for using tool results

split_dataset returns a File — call it twice:
    train_file = split_dataset(data, target_column, 0.2, time_column, "train")
    test_file  = split_dataset(data, target_column, 0.2, time_column, "test")

engineer_features returns a File — chain calls freely:
    eng = engineer_features(train_file, {{"rolling_columns": [...], "windows": [...]}})
    eng2 = engineer_features(eng, {{"normalize": True, "target_column": target_column}})

train_model returns a File — pass directly to evaluate_model:
    model_file = train_model(train_file, target_column, algorithm, hyperparams)
    eval_result = evaluate_model(model_file, test_file, target_column)

evaluate_model returns a dict — subscript reads are allowed:
    roc = eval_result["metrics"]["roc_auc"]

Do NOT use augmented assignment (+=), subscript assignment (d["k"]=v), or try/except.
Build dicts as literals only. The last expression (no assignment) is the return value.

## When fixing a previous error

Read the error and the failing code carefully before writing a fix. Identify the root
cause — do not just change variable names or add no-ops. Trace what each tool returns,
what each subsequent call expects, and where the mismatch is. Then fix the underlying
logic, not just the surface symptom.

## Pipeline design — you decide the structure

You are NOT required to follow a fixed sequence. Design the pipeline that makes most
sense for the experiment. Examples of valid approaches:

Baseline (no feature engineering):
    train_file = split_dataset(data, target_column, 0.2, time_column, "train")
    test_file = split_dataset(data, target_column, 0.2, time_column, "test")
    model_file = train_model(train_file, target_column, algorithm, hyperparams)
    eval_result = evaluate_model(model_file, test_file, target_column)
    {{"experiment_name": experiment_name, "algorithm": algorithm, "metrics": eval_result["metrics"], "confusion_matrix": eval_result["confusion_matrix"], "threshold_analysis": eval_result["threshold_analysis"], "n_samples": eval_result["n_samples"]}}

Two-stage feature engineering (rolling then normalize separately):
    train_file = split_dataset(data, target_column, 0.2, time_column, "train")
    test_file = split_dataset(data, target_column, 0.2, time_column, "test")
    rolled_train = engineer_features(train_file, {{"rolling_columns": ["vibration"], "windows": [6, 24]}})
    rolled_test  = engineer_features(test_file,  {{"rolling_columns": ["vibration"], "windows": [6, 24]}})
    eng_train = engineer_features(rolled_train, {{"normalize": True, "target_column": target_column}})
    eng_test  = engineer_features(rolled_test,  {{"normalize": True, "target_column": target_column}})
    model_file = train_model(eng_train, target_column, algorithm, hyperparams)
    eval_result = evaluate_model(model_file, eng_test, target_column)
    {{"experiment_name": experiment_name, "algorithm": algorithm, "metrics": eval_result["metrics"], "confusion_matrix": eval_result["confusion_matrix"], "threshold_analysis": eval_result["threshold_analysis"], "n_samples": eval_result["n_samples"]}}

Compare two class weightings and return the better model:
    train_file = split_dataset(data, target_column, 0.2, time_column, "train")
    test_file = split_dataset(data, target_column, 0.2, time_column, "test")
    model_a = train_model(train_file, target_column, "xgboost", {{"n_estimators": 100, "scale_pos_weight": 10}})
    model_b = train_model(train_file, target_column, "xgboost", {{"n_estimators": 100, "scale_pos_weight": 33}})
    eval_a = evaluate_model(model_a, test_file, target_column)
    eval_b = evaluate_model(model_b, test_file, target_column)
    best_eval = eval_a if eval_a["metrics"]["roc_auc"] > eval_b["metrics"]["roc_auc"] else eval_b
    {{"experiment_name": experiment_name, "algorithm": "xgboost", "metrics": best_eval["metrics"], "confusion_matrix": best_eval["confusion_matrix"], "threshold_analysis": best_eval["threshold_analysis"], "n_samples": best_eval["n_samples"]}}

SMOTE oversampling before training:
    train_file = split_dataset(data, target_column, 0.2, time_column, "train")
    test_file  = split_dataset(data, target_column, 0.2, time_column, "test")
    eng_train  = engineer_features(train_file, {{"rolling_columns": ["vibration_mms"], "windows": [6, 12]}})
    eng_test   = engineer_features(test_file,  {{"rolling_columns": ["vibration_mms"], "windows": [6, 12]}})
    resampled_train = resample_dataset(eng_train, target_column, {{"strategy": "smote", "target_ratio": 0.2}})
    model_file = train_model(resampled_train, target_column, algorithm, hyperparams)
    eval_result = evaluate_model(model_file, eng_test, target_column)
    {{"experiment_name": experiment_name, "algorithm": algorithm, "metrics": eval_result["metrics"], "confusion_matrix": eval_result["confusion_matrix"], "threshold_analysis": eval_result["threshold_analysis"], "n_samples": eval_result["n_samples"]}}

Feature engineering followed by feature selection:
    train_file = split_dataset(data, target_column, 0.2, time_column, "train")
    test_file  = split_dataset(data, target_column, 0.2, time_column, "test")
    eng_train  = engineer_features(train_file, {{"rolling_columns": ["vibration_mms", "temperature_c"], "windows": [6, 12, 24]}})
    eng_test   = engineer_features(test_file,  {{"rolling_columns": ["vibration_mms", "temperature_c"], "windows": [6, 12, 24]}})
    sel_train  = select_features(eng_train, target_column, {{"method": "mutual_info", "k": 15}})
    sel_test   = select_features(eng_test,  target_column, {{"method": "mutual_info", "k": 15}})
    model_file = train_model(sel_train, target_column, algorithm, hyperparams)
    eval_result = evaluate_model(model_file, sel_test, target_column)
    {{"experiment_name": experiment_name, "algorithm": algorithm, "metrics": eval_result["metrics"], "confusion_matrix": eval_result["confusion_matrix"], "threshold_analysis": eval_result["threshold_analysis"], "n_samples": eval_result["n_samples"]}}

Error analysis — explore what the model gets wrong, then return that as insight:
    train_file = split_dataset(data, target_column, 0.2, time_column, "train")
    test_file  = split_dataset(data, target_column, 0.2, time_column, "test")
    model_file = train_model(train_file, target_column, algorithm, hyperparams)
    pred_file  = get_predictions(model_file, test_file, target_column)
    error_analysis = explore_dataset(pred_file, {{"target_column": "correct", "class_distributions": ["vibration_mms", "temperature_c"]}})
    eval_result = evaluate_model(model_file, test_file, target_column)
    {{"experiment_name": experiment_name, "algorithm": algorithm, "metrics": eval_result["metrics"], "confusion_matrix": eval_result["confusion_matrix"], "threshold_analysis": eval_result["threshold_analysis"], "n_samples": eval_result["n_samples"], "error_analysis": error_analysis}}

The last expression MUST be a dict with at minimum these keys:
    experiment_name, algorithm, metrics, confusion_matrix, threshold_analysis, n_samples
Additional keys (e.g. error_analysis) are allowed and will appear in the report.

## Response format

Respond in exactly this format:

## Reasoning
[Your thinking: what pipeline makes sense for this experiment and why. Consider whether
feature engineering helps, whether class imbalance needs special treatment, whether
chaining multiple steps adds value, etc.]

## Code
```python
[your orchestration code]
```
"""
# {{/docs-fragment orchestration_prompt}}

def _build_analysis_system_prompt(max_iterations: int, current_iteration: int) -> str:
    remaining = max_iterations - current_iteration - 1
    return f"""\
You are an expert ML engineer analyzing experiment results to guide the next iteration
of model development.

You must respond with valid JSON only — no markdown, no explanation outside the JSON.

Response format:
{{
  "should_continue": true | false,
  "reasoning": "What you observed, what it tells you, and what to try next",
  "exploration_requests": [
    {{
      "question": "The specific hypothesis you are testing, e.g. 'Do failure cases show meaningfully higher vibration than healthy cases?'",
      "analysis_type": "class_distributions",
      "target_column": "failure_24h",
      "class_distributions": ["vibration_mms", "temperature_c"]
    }}
  ],
  "next_experiments": [
    {{
      "name": "descriptive experiment name",
      "algorithm": "xgboost" | "random_forest" | "gradient_boosting" | "logistic_regression",
      "hyperparams": {{ ... algorithm-specific hyperparams ... }},
      "feature_config": {{
        "group_column": "...",
        "time_column": "...",
        "rolling_columns": [...],
        "windows": [...],
        "lag_columns": [...],
        "lags": [...],
        "normalize": true | false,
        "drop_columns": [...],
        "fillna_method": "forward"
      }},
      "rationale": "Why this experiment is worth trying"
    }}
  ]
}}

exploration_requests rules:
- Max 2 requests per iteration.
- Each request targets EXACTLY ONE analysis_type. Do not mix multiple types in one request.
- Supported analysis_type values and their required config fields:
    "class_distributions" → requires: target_column, class_distributions (list of columns)
    "correlation_matrix"  → requires: correlation_matrix: true
    "temporal_trend"      → requires: temporal_trend: {{time_column, target_column, freq}}
    "group_stats"         → requires: group_stats: {{group_column, target_column}}
    "outlier_summary"     → requires: outlier_summary (list of columns)
    "feature_target_corr_by_group" → requires: feature_target_corr_by_group: {{group_column, target_column, feature_columns}}
- The "question" field is required. It must be a specific testable hypothesis, not a
  general request. Bad: "explore the data". Good: "Is vibration_mms higher for failures?"
- Set exploration_requests to [] if the current results already tell you enough to
  design the next experiments. Only explore when you have a concrete unanswered question.

When deciding next experiments, reason about WHAT WAS TRIED vs what hasn't been explored.
Each result includes used_feature_engineering, used_rolling_features, used_lag_features.
Think systematically: if no feature engineering was tried yet, does the data profile
suggest it would help (weak raw correlations, temporal/sequential structure)?
If feature engineering helped, can it be improved? Avoid experiments identical to ones tried.

Iteration context: this is iteration {current_iteration + 1} of {max_iterations} requested.
Remaining iterations allowed: {remaining}.
Set should_continue=false only if:
- Best ROC-AUC >= 0.97, OR
- No remaining iterations (remaining == 0), OR
- Results have genuinely plateaued (< 0.005 ROC-AUC improvement over last iteration
  AND you have already tried the most promising directions)
Otherwise keep exploring — the user asked for {max_iterations} iterations for a reason.
"""

def _build_initial_design_system_prompt() -> str:
    return """\
You are an expert ML engineer. Given a dataset profile and a problem description,
design the first batch of experiments to run.

You must respond with valid JSON only — no markdown, no explanation outside the JSON.

Response format:
{
  "problem_type": "binary_classification",
  "primary_metric": "roc_auc" | "f1" | "recall",
  "reasoning": "Brief description of your strategy",
  "experiments": [
    {
      "name": "descriptive experiment name",
      "algorithm": "xgboost" | "random_forest" | "gradient_boosting" | "logistic_regression",
      "hyperparams": { ... algorithm-specific hyperparams ... },
      "feature_config": {
        "group_column": "",
        "time_column": "",
        "rolling_columns": [],
        "windows": [],
        "lag_columns": [],
        "lags": [],
        "normalize": false,
        "drop_columns": [],
        "fillna_method": "forward"
      },
      "rationale": "Why this experiment makes sense given the data profile"
    }
  ]
}

Design 2-3 experiments for the first batch. Good first batches typically include:
- A fast baseline to establish a floor (e.g. logistic_regression with default settings)
- Your best initial hypothesis given the data profile
- Optionally one variant that tests a specific idea suggested by the profile

Use the dataset profile to guide your choices:
- feature_target_corr: weak raw correlations suggest feature engineering may help
- categorical_columns: note these are excluded from the model automatically
- is_imbalanced: handle with class_weight or scale_pos_weight
- Shape and column types should inform algorithm complexity (simpler models for small datasets)
- A time column suggests sequential structure; lag/rolling features may capture temporal patterns

The feature_config in each experiment describes what engineer_features should apply.
Leave all fields empty/false if no feature engineering is needed for that experiment.
The orchestration code generator will decide the exact pipeline — your job here is
to specify what the experiment is trying to learn, not to prescribe every implementation detail.
"""

# ---------------------------------------------------------------------------
# LLM client
# ---------------------------------------------------------------------------

def _openai_client():
    from openai import OpenAI
    return OpenAI(api_key=os.environ["OPENAI_API_KEY"])

async def _call_llm(system: str, messages: list[dict], model: str = "gpt-4o") -> str:
    """Call OpenAI and return the response text."""
    client = _openai_client()
    response = await asyncio.to_thread(
        client.chat.completions.create,
        model=model,
        messages=[{"role": "system", "content": system}, *messages],
        temperature=0.2,
    )
    return response.choices[0].message.content
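
# Illustrative sketch (not called by the agent). The OpenAI client call in
# `_call_llm` is synchronous, so it is wrapped in `asyncio.to_thread` to keep the
# event loop responsive while the request is in flight; other coroutines can make
# progress in the meantime. The hypothetical helper below uses `time.sleep` as a
# stand-in for a blocking call and runs several of them concurrently.
def _example_to_thread_overlap(tags: tuple[str, ...] = ("a", "b", "c")) -> list[str]:
    import asyncio
    import time

    def blocking_call(tag: str) -> str:
        time.sleep(0.05)  # stands in for a blocking network request
        return tag

    async def main() -> list[str]:
        # Each blocking call runs in a worker thread; gather preserves input order.
        return await asyncio.gather(*(asyncio.to_thread(blocking_call, t) for t in tags))

    return asyncio.run(main())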

def _extract_code(text: str) -> str:
    """Pull Python code out of a markdown code block.

    Prefers a ```python fence, then any ``` fence. If the closing fence is
    missing (for example, a truncated response), returns everything after the
    opening fence instead of raising.
    """
    for fence in ("```python", "```"):
        if fence in text:
            start = text.index(fence) + len(fence)
            end = text.find("```", start)
            return (text[start:] if end == -1 else text[start:end]).strip()
    return text.strip()

def _extract_reasoning(text: str) -> str:
    """Extract the ## Reasoning section from LLM response."""
    if "## Reasoning" in text:
        start = text.index("## Reasoning") + len("## Reasoning")
        if "## Code" in text:
            end = text.index("## Code")
            return text[start:end].strip()
        return text[start:].strip()
    return ""

def _parse_json(text: str) -> dict:
    """Extract and parse JSON from LLM response."""
    text = text.strip()
    if "```json" in text:
        start = text.index("```json") + 7
        end = text.index("```", start)
        text = text[start:end].strip()
    elif "```" in text:
        start = text.index("```") + 3
        end = text.index("```", start)
        text = text[start:end].strip()
    return json.loads(text)

# ---------------------------------------------------------------------------
# Display helpers
# ---------------------------------------------------------------------------

def _recommend_threshold(threshold_analysis: list, min_precision: float = 0.70) -> dict | None:
    """Find the threshold that maximises recall subject to precision >= min_precision."""
    candidates = [t for t in threshold_analysis if t["precision"] >= min_precision]
    if not candidates:
        return None
    return max(candidates, key=lambda t: t["recall"])

def _print_experiment_table(results: list["ExperimentResult"], best_name: str) -> None:
    """Print a ranked comparison table of all experiments."""
    sorted_results = sorted(results, key=lambda r: r.metrics.get("roc_auc", 0), reverse=True)
    print("\n" + "─" * 78)
    print(f"  {'Rank':<5} {'Experiment':<32} {'ROC-AUC':<9} {'F1':<7} {'Recall':<8} {'Note'}")
    print("─" * 78)
    for rank, r in enumerate(sorted_results, 1):
        note = "◀ winner" if r.name == best_name else ""
        roc = r.metrics.get("roc_auc", 0)
        f1 = r.metrics.get("f1", 0)
        recall = r.metrics.get("recall", 0)
        print(f"  {rank:<5} {r.name:<32} {roc:<9.4f} {f1:<7.4f} {recall:<8.4f} {note}")
    print("─" * 78)

def _print_threshold_recommendation(threshold_analysis: list, default_metrics: dict) -> None:
    """Print the operational threshold recommendation."""
    rec = _recommend_threshold(threshold_analysis)
    if not rec:
        return
    default_recall = default_metrics.get("recall", 0)
    default_precision = default_metrics.get("precision", 0)
    missed_pct = round((1 - rec["recall"]) * 100, 1)
    false_alarm_pct = round((1 - rec["precision"]) * 100, 1)

    print(f"\n  Recommended decision threshold: {rec['threshold']}")
    print(f"  ├─ Precision : {rec['precision']:.0%}   ({false_alarm_pct}% of alerts are false alarms)")
    print(f"  ├─ Recall    : {rec['recall']:.0%}   (catches {rec['recall']*100:.0f}% of actual failures)")
    print(f"  └─ F1        : {rec['f1']:.4f}")
    print(f"  Default threshold (0.5): Precision={default_precision:.0%}, Recall={default_recall:.0%}")
    if rec["recall"] > default_recall:
        extra = round((rec["recall"] - default_recall) * 100, 1)
        print(f"  → Lowering threshold catches {extra}% more failures at cost of more alerts")

# ---------------------------------------------------------------------------
# Orchestration code generation (durable Flyte task with Flyte report)
# ---------------------------------------------------------------------------

@agent_env.task
async def plan_experiment(
    experiment_json: str,
    profile_json: str,
    target_column: str,
    time_column: str,
    previous_error: str = "",
    previous_code: str = "",
    llm_model: str = "gpt-4o",
) -> str:
    """LLM plans a single experiment: reasons about the pipeline and generates Monty code.

    Runs as a durable Flyte task so each experiment's planning step is traceable.
    Returns a JSON string: {"code": "...", "reasoning": "..."}.

    Args:
        experiment_json: JSON string of the experiment spec (name, algorithm, hyperparams, ...).
        profile_json: JSON string of the dataset profile from profile_dataset.
        target_column: Name of the target column.
        time_column: Time column for temporal splitting, or empty string.
        previous_error: Error message from the previous attempt (empty on first try).
        previous_code: Code that failed on the previous attempt (empty on first try).
        llm_model: OpenAI model identifier.

    Returns:
        str — JSON string with keys "code" and "reasoning".
    """
    experiment = json.loads(experiment_json)
    profile = json.loads(profile_json)
    exp_name = experiment.get("name", "experiment")

    # Strip rationale — it was written by the design LLM to explain *why* this
    # experiment was chosen. Passing it here causes plan_experiment to parrot it
    # back as "reasoning" instead of independently thinking about *how* to build
    # the best pipeline. Keep only the technical spec.
    pipeline_spec = {
        k: v for k, v in experiment.items()
        if k not in ("rationale",)
    }

    system = _build_orchestration_system_prompt(profile)

    user_content = textwrap.dedent(f"""
        Design and implement the best pipeline for this experiment:

        Name: {exp_name}
        Algorithm: {pipeline_spec.get("algorithm")}
        Hyperparams: {json.dumps(pipeline_spec.get("hyperparams", {}), indent=2)}
        Feature config hint: {json.dumps(pipeline_spec.get("feature_config", {}), indent=2)}

        Available sandbox inputs:
        - data: File  — the full dataset CSV
        - target_column: str = "{target_column}"
        - time_column: str = "{time_column}"  (empty string means no time ordering)
        - experiment_name: str = "{exp_name}"

        The feature config hint is a suggestion from the experiment designer — you can
        follow it, improve on it, or override it if the dataset context and your ML
        judgment suggest a better approach. In your ## Reasoning, explain your actual
        pipeline decisions: what you chose to do (or not do) and why, based on the
        dataset profile above. Do not restate the experiment name or why it was chosen.
    """).strip()

    messages = [{"role": "user", "content": user_content}]
    if previous_code and previous_error:
        messages = [
            {"role": "user", "content": user_content},
            {"role": "assistant", "content": f"```python\n{previous_code}\n```"},
            {"role": "user", "content": f"That code failed with this error:\n\n{previous_error}\n\nPlease fix it."},
        ]

    response = await _call_llm(system, messages, llm_model)
    reasoning = _extract_reasoning(response)
    code = _extract_code(response)
    return json.dumps({"code": code, "reasoning": reasoning})

@flyte.trace
async def design_experiments(
    problem_description: str,
    profile_json: str,
    llm_model: str = "gpt-4o",
) -> str:
    """LLM designs the initial batch of experiments given problem + dataset profile.

    Traced so the prompt/response is visible in the Flyte UI and results are
    cached for deterministic replay on crash/retry.
    Returns raw LLM response (JSON string matching InitialDesign schema).
    """
    design_prompt = textwrap.dedent(f"""
        Problem description: {problem_description}

        Dataset profile:
        {profile_json}

        Design the first batch of experiments.
    """).strip()
    return await _call_llm(
        _build_initial_design_system_prompt(),
        [{"role": "user", "content": design_prompt}],
        llm_model,
    )

@flyte.trace
async def analyze_iteration(
    analysis_prompt: str,
    max_iterations: int,
    current_iteration: int,
    llm_model: str = "gpt-4o",
) -> str:
    """LLM analyzes experiment results and decides whether/how to continue.

    Traced so the prompt/response is visible in the Flyte UI and results are
    cached for deterministic replay on crash/retry.
    Returns raw LLM response (JSON string matching IterationDecision schema).
    """
    return await _call_llm(
        _build_analysis_system_prompt(max_iterations, current_iteration),
        [{"role": "user", "content": analysis_prompt}],
        llm_model,
    )

@flyte.trace
async def plan_followup(
    analysis_prompt: str,
    analysis_response: str,
    followup_prompt: str,
    max_iterations: int,
    current_iteration: int,
    llm_model: str = "gpt-4o",
) -> str:
    """LLM designs next experiments after targeted data explorations.

    Traced so the prompt/response is visible in the Flyte UI and results are
    cached for deterministic replay on crash/retry.
    Returns raw LLM response (JSON string with {"next_experiments": [...]}).
    """
    return await _call_llm(
        _build_analysis_system_prompt(max_iterations, current_iteration),
        [
            {"role": "user", "content": analysis_prompt},
            {"role": "assistant", "content": analysis_response},
            {"role": "user", "content": followup_prompt},
        ],
        llm_model,
    )

def _corrupt_experiment_for_demo(exp_dict: dict) -> dict:
    """Introduce a deliberate error into the first experiment for demo purposes.

    Corrupts the algorithm name so the LLM must recover from a known-bad value.
    The retry loop will catch this, regenerate with the error message, and fix it.
    """
    corrupted = dict(exp_dict)
    corrupted["algorithm"] = corrupted["algorithm"] + "_INVALID"
    return corrupted
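
# Illustrative sketch (not called by the agent). `dict(exp_dict)` is a shallow
# copy: reassigning the top-level "algorithm" key leaves the caller's dict
# untouched, but nested values such as "hyperparams" are shared, so mutating them
# in place would leak into the original. The hypothetical helper shows both.
def _example_shallow_copy() -> tuple[dict, dict]:
    original = {"algorithm": "xgboost", "hyperparams": {"max_depth": 4}}
    corrupted = dict(original)
    corrupted["algorithm"] = corrupted["algorithm"] + "_INVALID"  # rebinds a key on the copy only
    return original, corrupted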

# ---------------------------------------------------------------------------
# Main agent loop
# ---------------------------------------------------------------------------

@dataclass
class ExperimentResult:
    name: str
    algorithm: str
    metrics: dict
    confusion_matrix: dict
    threshold_analysis: list
    n_samples: int
    code: str
    attempts: int
    reasoning: str = ""
    error: str = ""

@dataclass
class AgentResult:
    model_card: str
    best_experiment: str
    best_metrics: dict
    all_results: list[ExperimentResult]
    iterations: int
    total_experiments: int

async def _run_experiment(
    exp: "ExperimentConfig",
    exp_dict: dict,
    inject_failure: bool,
    data: File,
    target_column: str,
    time_column: str,
    profile: dict,
    llm_model: str,
    max_retries: int,
) -> "ExperimentResult | None":
    """Run a single experiment with retries. Returns None on total failure."""
    exp_name = exp.name
    profile_json = json.dumps(profile)
    print(f"\n   ┌─ {exp_name}  [{exp.algorithm}]")
    if exp.rationale:
        for line in textwrap.wrap(exp.rationale, width=58):
            print(f"   │  {line}")
    if inject_failure:
        print(f"   │  [injecting failure for demo: algorithm='{exp_dict['algorithm']}']")

    code = ""
    error = ""
    result = None
    attempt = 0

    reasoning = ""
    # {{docs-fragment retry_loop}}
    for attempt in range(max_retries):
        try:
            with flyte.group(exp_name):
                plan_json = await plan_experiment.aio(
                    experiment_json=json.dumps(exp_dict),
                    profile_json=profile_json,
                    target_column=target_column,
                    time_column=time_column,
                    previous_error=error,
                    previous_code=code,
                    llm_model=llm_model,
                )
                plan = json.loads(plan_json)
                code = plan["code"]
                reasoning = plan.get("reasoning", "")
                result = await flyte.sandbox.orchestrate_local(
                    code,
                    inputs={"data": data, "target_column": target_column,
                            "time_column": time_column, "experiment_name": exp_name},
                    tasks=TOOLS,
                )
            error = ""
            break
        except Exception as exc:
            error = str(exc)
            short_error = error[:100] + "..." if len(error) > 100 else error
            print(f"   │  attempt {attempt + 1} failed: {short_error}")
            print(f"   │  → asking LLM to fix and retry...")
            if inject_failure and attempt == 0:
                exp_dict = exp.model_dump()
    # {{/docs-fragment retry_loop}}

    if result and not error:
        exp_result = ExperimentResult(
            name=exp_name,
            algorithm=exp.algorithm,
            metrics=result.get("metrics", {}),
            confusion_matrix=result.get("confusion_matrix", {}),
            threshold_analysis=result.get("threshold_analysis", []),
            n_samples=result.get("n_samples", 0),
            code=code,
            reasoning=reasoning,
            attempts=attempt + 1,
        )
        m = exp_result.metrics
        attempts_note = f" (recovered after {attempt + 1} attempts)" if attempt > 0 else ""
        print(f"   └─ ROC-AUC={m.get('roc_auc')}, F1={m.get('f1')}, Recall={m.get('recall')}{attempts_note}")
        return exp_result

    print(f"   └─ FAILED after {max_retries} attempts — skipping.")
    return None
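
# Illustrative sketch (not called by the agent). The retry loop in
# `_run_experiment` follows a generate -> execute -> feed-the-error-back
# pattern: each failed attempt's code and error message are passed to the next
# planning call. The hypothetical helper below reproduces that control flow with
# plain callables in place of plan_experiment and flyte.sandbox.orchestrate_local.
def _example_retry_with_feedback(generate, execute, max_retries: int = 3):
    """Return (result, attempts_used); result is None if every attempt fails."""
    code, error = "", ""
    for attempt in range(max_retries):
        code = generate(previous_code=code, previous_error=error)
        try:
            return execute(code), attempt + 1
        except Exception as exc:  # in the real loop, planning also sits inside the try
            error = str(exc)
    return None, max_retries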

async def run_agent(
    data: File,
    problem_description: str,
    target_column: str,
    time_column: str = "",
    max_iterations: int = 3,
    max_retries_per_experiment: int = 3,
    llm_model: str = "gpt-4o",
    inject_failure: bool = False,
) -> AgentResult:
    """Run the MLE agent end-to-end.

    Args:
        data: CSV file containing the dataset.
        problem_description: Natural language description of the ML problem.
        target_column: Name of the target column to predict.
        time_column: Optional column to use for time-based train/test split.
        max_iterations: Maximum number of experiment iterations to run.
        max_retries_per_experiment: Max times to retry a failed sandbox execution.
        llm_model: OpenAI model to use (default: gpt-4o).
        inject_failure: If True, corrupts the first experiment to demonstrate self-healing.
    """
    print(f"\n{'='*60}")
    print(f"MLE Agent starting")
    print(f"Problem: {problem_description}")
    print(f"Target: {target_column}")
    if inject_failure:
        print(f"[demo mode: failure injection enabled]")
    print(f"{'='*60}\n")

    # {{docs-fragment phase1_profile}}
    # --- Phase 1: Profile the dataset (trusted tool, LLM never sees raw data) ---
    print(">> Phase 1: Profiling dataset...")
    with flyte.group("profile"):
        profile = await profile_dataset(data, target_column)
    # {{/docs-fragment phase1_profile}}
    print(f"   Shape: {profile['shape']}, Classes: {profile['target_distribution']}")
    print(f"   Imbalanced: {profile['is_imbalanced']}, Columns: {len(profile['columns'])}")
    corr = profile.get("feature_target_corr", {})
    top_corr = list(corr.items())[:5]
    print(f"   Top correlations: {', '.join(f'{k}={v:+.3f}' for k,v in top_corr)}")

    # Stream report: dataset summary
    await flyte.report.log.aio(
        f"<h1>MLE Agent Run</h1>"
        f"<p><b>Problem:</b> {problem_description}</p>"
        f"<p><b>Dataset:</b> {profile['shape'][0]:,} rows × {profile['shape'][1]} cols &nbsp;|&nbsp; "
        f"Class balance: {profile['class_balance']} &nbsp;|&nbsp; Imbalanced: {profile['is_imbalanced']}</p>"
        f"<p><b>Top feature-target correlations (raw):</b> "
        + ", ".join(f"{k}: {v:+.3f}" for k, v in top_corr) +
        f"</p><hr>",
        do_flush=True,
    )

    # --- Phase 2: LLM designs initial experiments ---
    print("\n>> Phase 2: Designing initial experiments...")
    design_response = await design_experiments(
        problem_description=problem_description,
        profile_json=json.dumps(profile),
        llm_model=llm_model,
    )
    design = InitialDesign.model_validate(_parse_json(design_response))
    print(f"   Primary metric: {design.primary_metric}")
    print(f"   Strategy: {design.reasoning}")
    print(f"   Experiments planned: {len(design.experiments)}")

    all_results: list[ExperimentResult] = []
    iteration_log: list[dict] = []  # tracks per-iteration decisions + explorations for summary
    current_experiments: list[ExperimentConfig] = design.experiments
    first_experiment = True

    # --- Phase 3: Iterative experiment loop ---
    iteration = -1  # keeps `iteration + 1` defined (as 0) if max_iterations <= 0
    for iteration in range(max_iterations):
        experiments = current_experiments

        if not experiments:
            print(f"\n>> No experiments to run in iteration {iteration + 1}. Stopping.")
            break

        print(f"\n>> Phase 3.{iteration + 1}: Running {len(experiments)} experiment(s) in parallel...")

        # Assign names and prepare dicts before launching in parallel
        exp_batch = []
        for i, exp in enumerate(experiments):
            if not exp.name:
                exp.name = f"experiment_{len(all_results) + i + 1}"
            exp_dict = exp.model_dump()
            inject_this = inject_failure and first_experiment and i == 0
            if inject_this:
                exp_dict = _corrupt_experiment_for_demo(exp_dict)
            first_experiment = False
            exp_batch.append((exp, exp_dict, inject_this))

        # {{docs-fragment parallel_execute}}
        batch_results = await asyncio.gather(*[
            _run_experiment(
                exp=exp,
                exp_dict=exp_dict,
                inject_failure=inject_this,
                data=data,
                target_column=target_column,
                time_column=time_column,
                profile=profile,
                llm_model=llm_model,
                max_retries=max_retries_per_experiment,
            )
            for exp, exp_dict, inject_this in exp_batch
        ])
        # {{/docs-fragment parallel_execute}}

        for exp_result in batch_results:
            if exp_result is not None:
                all_results.append(exp_result)
                # Stream report: each experiment as it completes
                m = exp_result.metrics
                html = (
                    f"<h3>Iteration {iteration + 1} — {exp_result.name}</h3>"
                    f"<p><b>Algorithm:</b> {exp_result.algorithm} &nbsp;|&nbsp; "
                    f"<b>ROC-AUC:</b> {m.get('roc_auc')} &nbsp;|&nbsp; "
                    f"<b>F1:</b> {m.get('f1')} &nbsp;|&nbsp; "
                    f"<b>Recall:</b> {m.get('recall')} &nbsp;|&nbsp; "
                    f"<b>Attempts:</b> {exp_result.attempts}</p>"
                )
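                # Escape LLM-generated text so characters such as "<" render literally in <pre>
                from html import escape
                if exp_result.reasoning:
                    html += f"<details><summary>Reasoning</summary><pre>{escape(exp_result.reasoning)}</pre></details>"
                html += f"<details><summary>Generated Code</summary><pre>{escape(exp_result.code)}</pre></details>"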
                await flyte.report.log.aio(html, do_flush=True)

        # --- Phase 4: Analyze results, decide whether to iterate ---
        if all_results and iteration < max_iterations - 1:
            print(f"\n>> Phase 4.{iteration + 1}: Analyzing results, deciding next steps...")
            results_summary = [
                {
                    "experiment_name": r.name,
                    "algorithm": r.algorithm,
                    "metrics": r.metrics,
                    "confusion_matrix": r.confusion_matrix,
                    "used_feature_engineering": "engineer_features" in r.code,
                    "used_rolling_features": "rolling_columns" in r.code,
                    "used_lag_features": "lag_columns" in r.code,
                }
                for r in all_results
            ]
            analysis_prompt = textwrap.dedent(f"""
                Problem: {problem_description}
                Dataset profile: shape={profile['shape']}, imbalanced={profile['is_imbalanced']}
                Feature-target correlations (raw): {json.dumps(profile.get('feature_target_corr', {}), indent=2)}

                Experiment results so far (iteration {iteration + 1}):
                {json.dumps(results_summary, indent=2)}

                Should we run more experiments? If yes, request any data explorations
                you need, then specify what experiments to run next.
            """).strip()

            analysis_response = await analyze_iteration(
                analysis_prompt=analysis_prompt,
                max_iterations=max_iterations,
                current_iteration=iteration,
                llm_model=llm_model,
            )
            decision = IterationDecision.model_validate(_parse_json(analysis_response))
            verdict = "continuing" if decision.should_continue else "stopping"
            print(f"   Decision: {verdict}")
            print(f"   Reasoning: {decision.reasoning}")

            # Stream report: analysis decision
            await flyte.report.log.aio(
                f"<h3>Analysis — Iteration {iteration + 1}</h3>"
                f"<p><b>Decision:</b> {verdict}</p>"
                f"<p><b>Reasoning:</b> {decision.reasoning}</p>",
                do_flush=True,
            )

            # Track this iteration for the experiment journey summary
            iter_entry = {
                "iteration": iteration + 1,
                "experiments": [r.name for r in batch_results if r is not None],
                "best_roc_auc": max(
                    (r.metrics.get("roc_auc", 0) for r in all_results), default=0
                ),
                "reasoning": decision.reasoning,
                "explorations": [],
            }

            # --- Targeted exploration before next iteration ---
            if decision.should_continue and decision.exploration_requests:
                print(f"   Running {len(decision.exploration_requests)} exploration request(s)...")
                exploration_questions = []
                exploration_results = []

                for i, req in enumerate(decision.exploration_requests):
                    question = req.get("question", f"Exploration {i + 1}")
                    # Strip agent-level metadata — tool only needs the analysis config
                    tool_config = {k: v for k, v in req.items() if k not in ("question", "analysis_type")}

                    print(f"   Q: {question}")
                    with flyte.group(f"explore_{iteration + 1}_{i + 1}"):
                        result = await explore_dataset(data, tool_config)
                    exploration_questions.append(question)
                    exploration_results.append(result)
                    iter_entry["explorations"].append({"question": question})

                    await flyte.report.log.aio(
                        f"<h4>Exploration {i + 1}</h4>"
                        f"<p><b>Question:</b> {question}</p>"
                        f"<details><summary>Results</summary><pre>{json.dumps(result, indent=2)}</pre></details>",
                        do_flush=True,
                    )

                # Build follow-up that explicitly connects each question to its answer
                qa_pairs = "\n\n".join(
                    f'Question {i + 1}: "{q}"\nResult:\n{json.dumps(r, indent=2)}'
                    for i, (q, r) in enumerate(zip(exploration_questions, exploration_results))
                )
                followup_prompt = textwrap.dedent(f"""
                    You requested {len(exploration_results)} targeted exploration(s).
                    Here is what you asked and what you learned:

                    {qa_pairs}

                    Given what you learned and your earlier reasoning:
                    "{decision.reasoning}"

                    Now specify the next experiments. For each experiment, briefly state
                    which exploration insight informed your choice.
                    Respond with valid JSON: {{"next_experiments": [...same schema as before...]}}
                """).strip()
                followup_response = await plan_followup(
                    analysis_prompt=analysis_prompt,
                    analysis_response=analysis_response,
                    followup_prompt=followup_prompt,
                    max_iterations=max_iterations,
                    current_iteration=iteration,
                    llm_model=llm_model,
                )
                followup = _parse_json(followup_response)
                current_experiments = IterationDecision.model_validate({
                    "should_continue": True,
                    "reasoning": decision.reasoning,
                    "next_experiments": followup.get("next_experiments", []),
                }).next_experiments
                print(f"   Post-exploration: {len(current_experiments)} experiment(s) planned")
            else:
                current_experiments = decision.next_experiments

            iteration_log.append(iter_entry)

            if not decision.should_continue:
                break

    # --- Phase 5: Rank all results and generate model card ---
    print(f"\n>> Phase 5: Ranking {len(all_results)} experiment(s) and generating model card...")

    if not all_results:
        return AgentResult(
            model_card="No experiments completed successfully.",
            best_experiment="",
            best_metrics={},
            all_results=[],
            iterations=iteration + 1,
            total_experiments=0,
        )

    ranking_input = [
        {
            "experiment_name": r.name,
            "metrics": r.metrics,
            "confusion_matrix": r.confusion_matrix,
        }
        for r in all_results
    ]
    with flyte.group("rank"):
        ranking = await rank_experiments(json.dumps(ranking_input))
    best_name = ranking["best_experiment"]
    best_result = next(r for r in all_results if r.name == best_name)

    _print_experiment_table(all_results, best_name)
    _print_threshold_recommendation(best_result.threshold_analysis, best_result.metrics)

    # Stream report: final rankings table
    rows = "".join(
        f"<tr><td>{row['rank']}</td>"
        f"<td>{'<b>' if row['experiment_name'] == best_name else ''}"
        f"{row['experiment_name']}"
        f"{'</b>' if row['experiment_name'] == best_name else ''}</td>"
        f"<td>{row['roc_auc']}</td><td>{row['f1']}</td>"
        f"<td>{row['recall']}</td><td>{row['precision']}</td></tr>"
        for row in ranking.get("ranking", [])
    )
    await flyte.report.log.aio(
        f"<hr><h2>Final Rankings</h2>"
        f"<table border='1' cellpadding='6' cellspacing='0'>"
        f"<tr><th>Rank</th><th>Experiment</th><th>ROC-AUC</th><th>F1</th><th>Recall</th><th>Precision</th></tr>"
        f"{rows}</table>"
        f"<p>{ranking.get('summary', '')}</p>",
        do_flush=True,
    )

    # Stream report: experiment journey summary
    journey_rows = ""
    for entry in iteration_log:
        exps = ", ".join(entry["experiments"]) if entry["experiments"] else "—"
        explorations = "; ".join(e["question"] for e in entry["explorations"]) if entry["explorations"] else "—"
        short_reasoning = (entry["reasoning"][:120] + "…") if len(entry["reasoning"]) > 120 else entry["reasoning"]
        journey_rows += (
            f"<tr>"
            f"<td style='text-align:center'>{entry['iteration']}</td>"
            f"<td>{exps}</td>"
            f"<td style='text-align:center'>{entry['best_roc_auc']:.4f}</td>"
            f"<td>{short_reasoning}</td>"
            f"<td>{explorations}</td>"
            f"</tr>"
        )
    await flyte.report.log.aio(
        f"<hr><h2>Experiment Journey</h2>"
        f"<table border='1' cellpadding='6' cellspacing='0' style='width:100%;border-collapse:collapse'>"
        f"<tr><th>Iter</th><th>Experiments</th><th>Best ROC-AUC</th><th>Key insight</th><th>Explorations</th></tr>"
        f"{journey_rows}"
        f"</table>",
        do_flush=True,
    )

    model_card = await _generate_model_card(
        problem_description=problem_description,
        profile=profile,
        all_results=all_results,
        best_result=best_result,
        ranking=ranking,
        iteration_log=iteration_log,
        llm_model=llm_model,
    )

    print(f"\n{'='*60}")
    print(f"DONE — Best model: {best_name}")
    print(f"       ROC-AUC={best_result.metrics.get('roc_auc')}, F1={best_result.metrics.get('f1')}")
    print(f"{'='*60}\n")

    return AgentResult(
        model_card=model_card,
        best_experiment=best_name,
        best_metrics=best_result.metrics,
        all_results=all_results,
        iterations=iteration + 1,
        total_experiments=len(all_results),
    )

async def _generate_model_card(
    problem_description: str,
    profile: dict,
    all_results: list[ExperimentResult],
    best_result: ExperimentResult,
    ranking: dict,
    iteration_log: list[dict],
    llm_model: str,
) -> str:
    """Generate a markdown model card summarizing the winning model."""
    system = textwrap.dedent("""
        You are an ML engineer writing a model card for a trained model.
        Write in markdown. Be concise but informative. Include:
        - Problem statement
        - Dataset summary
        - Experiment journey (brief per-iteration narrative: what was tried, what was learned, what changed)
        - Experiment summary (table of all experiments with metrics)
        - Winning model details (algorithm, key hyperparams, metrics, threshold analysis)
        - Recommendations for deployment (decision threshold, monitoring)
    """).strip()

    results_text = "\n".join(
        f"- {r.name} ({r.algorithm}): ROC-AUC={r.metrics.get('roc_auc')}, "
        f"F1={r.metrics.get('f1')}, Recall={r.metrics.get('recall')}"
        for r in all_results
    )

    journey_text = ""
    if iteration_log:
        journey_text = "\n\nIteration log:\n" + "\n".join(
            f"  Iteration {e['iteration']}: ran [{', '.join(e['experiments'])}], "
            f"best ROC-AUC so far={e['best_roc_auc']:.4f}. "
            f"Key insight: {e['reasoning'][:200]}. "
            + (f"Explorations: {'; '.join(x['question'] for x in e['explorations'])}" if e['explorations'] else "")
            for e in iteration_log
        )

    user_content = textwrap.dedent(f"""
        Problem: {problem_description}

        Dataset: {profile['shape'][0]} rows × {profile['shape'][1]} cols.
        Class balance: {profile['class_balance']}
        Imbalanced: {profile['is_imbalanced']}
        {journey_text}

        All experiments:
        {results_text}

        Best model: {best_result.name} ({best_result.algorithm})
        Metrics: {json.dumps(best_result.metrics, indent=2)}
        Confusion matrix: {json.dumps(best_result.confusion_matrix, indent=2)}
        Threshold analysis: {json.dumps(best_result.threshold_analysis, indent=2)}

        Ranking summary: {ranking['summary']}
    """).strip()

    response = await _call_llm(system, [{"role": "user", "content": user_content}], llm_model)
    return response

# ---------------------------------------------------------------------------
# Durable entrypoint (runs the agent as a Flyte task in the cloud)
# ---------------------------------------------------------------------------

# {{docs-fragment entrypoint}}
@agent_env.task(retries=1, report=True)
async def mle_agent_task(
    data: File,
    problem_description: str,
    target_column: str,
    time_column: str = "",
    max_iterations: int = 3,
) -> str:
    """Durable Flyte task entrypoint for the MLE agent."""
    result = await run_agent(
        data=data,
        problem_description=problem_description,
        target_column=target_column,
        time_column=time_column,
        max_iterations=max_iterations,
    )
    return result.model_card
# {{/docs-fragment entrypoint}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/mle_bot/mle_bot/agent.py*

**Phase 4: Analyze and iterate.** After each batch completes, the LLM reviews the results and decides whether to continue. Before designing the next round, it can optionally request targeted data explorations (at most two per iteration). If it does (e.g., "do failure cases show higher vibration readings?"), the agent runs `explore_dataset` with those configurations, feeds the results back to the LLM, and lets it refine the next batch of experiments based on what it learned. The loop continues until the LLM decides to stop, the target metric is reached (the analysis prompt sets this at a best ROC-AUC of 0.97), or the maximum number of iterations is exhausted.
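The three stopping conditions can be sketched as a small pure function. This is an illustration, not the agent's code: in the agent, the LLM itself returns `should_continue` and is *instructed* to stop at the 0.97 ROC-AUC target. The helper name `stop_reason`, its signature, and the placeholder scores are assumptions made for the example.

```python
from typing import Optional


def stop_reason(
    llm_wants_to_continue: bool,
    best_roc_auc: float,
    iteration: int,
    max_iterations: int,
    target_roc_auc: float = 0.97,
) -> Optional[str]:
    """Return why the experiment loop should stop, or None to run another iteration.

    Hypothetical helper: `iteration` is 0-indexed, matching the agent's
    `iterations=iteration + 1` bookkeeping.
    """
    if not llm_wants_to_continue:
        return "llm_stopped"  # the LLM returned should_continue=false
    if best_roc_auc >= target_roc_auc:
        return "target_reached"
    if iteration + 1 >= max_iterations:
        return "max_iterations"  # iteration budget used up
    return None


# Loop skeleton with placeholder scores standing in for each iteration's best ROC-AUC.
placeholder_scores = [0.91, 0.93, 0.95]
for iteration, best in enumerate(placeholder_scores):
    reason = stop_reason(True, best, iteration, max_iterations=3)
    if reason is not None:
        break
print(iteration, reason)  # 2 max_iterations
```

Returning a reason string rather than a bare boolean makes it easy to record *why* a run ended, for example in the final report.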

## Running LLM-generated code in Flyte's sandbox

This is where it gets interesting. The LLM doesn't just pick parameters from a dropdown. For each experiment, it writes actual Python code that decides how to compose the tool functions into a pipeline. Maybe it splits the data, engineers rolling window features, applies SMOTE resampling on the training split, trains an XGBoost model, and evaluates it. Or maybe it skips feature engineering entirely for a baseline. The LLM decides the structure.

That code runs inside Flyte's sandbox, a restricted execution environment that enforces strict constraints:

- No `import` statements. The only callable functions are the ones you explicitly provide.
- No network access and no filesystem access.
- No `try`/`except`, no `class` definitions, no augmented assignment (`+=`).
- No `with` statements, no generators, no `global`/`nonlocal`.
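To make these rules concrete, here is a minimal sketch that flags the same constructs with Python's standard `ast` module before any code is submitted. It is a hypothetical pre-flight check for illustration only. Flyte's sandbox enforces its own rules and this is not its implementation. The helper name `find_forbidden_constructs` is an assumption.

```python
import ast

# Node types corresponding to the restrictions listed above.
FORBIDDEN_NODES = {
    ast.Import: "import",
    ast.ImportFrom: "import",
    ast.Try: "try/except",
    ast.ClassDef: "class definition",
    ast.AugAssign: "augmented assignment",
    ast.With: "with statement",
    ast.AsyncWith: "with statement",
    ast.Yield: "generator",
    ast.YieldFrom: "generator",
    ast.GeneratorExp: "generator",
    ast.Global: "global",
    ast.Nonlocal: "nonlocal",
}


def find_forbidden_constructs(code: str) -> list[tuple[int, str]]:
    """Return sorted (line number, construct) pairs for every disallowed node in `code`."""
    problems = []
    for node in ast.walk(ast.parse(code)):
        label = FORBIDDEN_NODES.get(type(node))
        if label is not None:
            problems.append((node.lineno, label))
    return sorted(problems)


snippet = """
import os
total = 0
total += 1
x = sum(v for v in [1, 2, 3])
"""
print(find_forbidden_constructs(snippet))
# [(2, 'import'), (4, 'augmented assignment'), (5, 'generator')]
```

A check like this only reports problems. It does not make code safe to run with a normal interpreter, which is why the real execution still happens inside the sandbox.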

The sandbox sees your pre-approved tool functions as plain function calls. When the code calls `train_model(...)`, the sandbox pauses execution, dispatches the call to Flyte (which runs it as a durable task on cloud compute with the resources declared on `tool_env`), waits for the result, and resumes. The LLM-generated code looks like synchronous Python, but under the hood each tool call is a full Flyte task execution.
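The pause, dispatch, and resume cycle can be imitated in a few dozen lines. The sketch below is a toy and not how `flyte.sandbox` works internally. It walks the parsed code one statement at a time, `await`s an allowlisted async function whenever it meets a call, and returns the value of the final bare expression. The tool functions `split` and `train` are made-up stand-ins for Flyte tasks, and `run_restricted` is a hypothetical name.

```python
import ast
import asyncio
from typing import Any, Awaitable, Callable


async def run_restricted(
    code: str,
    inputs: dict[str, Any],
    tools: dict[str, Callable[..., Awaitable[Any]]],
) -> Any:
    """Toy interpreter for a tiny Python subset (requires Python 3.9+)."""
    env = dict(inputs)

    async def ev(node: ast.expr) -> Any:
        if isinstance(node, ast.Constant):
            return node.value
        if isinstance(node, ast.Name):
            return env[node.id]
        if isinstance(node, ast.List):
            return [await ev(e) for e in node.elts]
        if isinstance(node, ast.Dict):
            return {await ev(k): await ev(v) for k, v in zip(node.keys, node.values)}
        if isinstance(node, ast.Subscript):
            return (await ev(node.value))[await ev(node.slice)]
        if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
            fn = tools[node.func.id]  # KeyError for anything not allowlisted
            args = [await ev(a) for a in node.args]
            kwargs = {kw.arg: await ev(kw.value) for kw in node.keywords}
            # The rest of the program waits here until the "task" finishes.
            return await fn(*args, **kwargs)
        raise ValueError(f"unsupported expression: {type(node).__name__}")

    result = None
    for stmt in ast.parse(code).body:
        if isinstance(stmt, ast.Assign) and isinstance(stmt.targets[0], ast.Name):
            env[stmt.targets[0].id] = await ev(stmt.value)
            result = None
        elif isinstance(stmt, ast.Expr):
            result = await ev(stmt.value)  # a bare expression becomes the candidate result
        else:
            raise ValueError(f"unsupported statement: {type(stmt).__name__}")
    return result


calls = []


async def split(data, part):
    calls.append(("split", part))
    await asyncio.sleep(0)  # stand-in for a remote task execution
    return [r for r in data if (r % 2 == 0) == (part == "train")]


async def train(rows):
    calls.append(("train", len(rows)))
    return {"metrics": {"n": len(rows)}}


program = '''
train_rows = split(data, "train")
model = train(train_rows)
{"n_train": model["metrics"]["n"]}
'''
out = asyncio.run(run_restricted(program, {"data": [1, 2, 3, 4, 5, 6]}, {"split": split, "train": train}))
print(out)    # {'n_train': 3}
print(calls)  # [('split', 'train'), ('train', 3)]
```

From the program's point of view, each call just returns a value. The `await` inside `ev` is where a real system would hand the call off to remote compute.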

Here's how the sandbox is invoked:

```python
result = await flyte.sandbox.orchestrate_local(
    code,
    inputs={
        "data": data,
        "target_column": target_column,
        "time_column": time_column,
        "experiment_name": exp_name,
    },
    tasks=TOOLS,
)
```

The `code` parameter is a string of Python generated by the LLM. `inputs` provides the variables that the code can reference. `tasks` is the allowlist: a list of Flyte task functions that the code is permitted to call. Nothing else is available.
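One way to picture the allowlist is as a set of function names, built the same way `TOOLS_BY_NAME` in the listing builds its keys (unwrapping a `.func` attribute when present). The sketch below lists every callee in a piece of code that falls outside that set. It is an illustrative static check with hypothetical helper names, not part of the Flyte API. The sandbox itself rejects such calls at run time.

```python
import ast
from typing import Callable, Iterable


def allowlist_names(tools: Iterable[Callable]) -> set[str]:
    """Collect tool names, unwrapping a `.func` attribute when present."""
    return {getattr(t, "func", t).__name__ for t in tools}


def unknown_calls(code: str, allowed: set[str]) -> list[str]:
    """List every callee in `code` that is not on the allowlist."""
    called = set()
    for node in ast.walk(ast.parse(code)):
        if isinstance(node, ast.Call):
            # ast.unparse turns the callee back into source text: "train_model", "os.system", ...
            called.add(ast.unparse(node.func))
    return sorted(called - allowed)


# Plain stand-ins for the real Flyte tasks.
def split_dataset(*args): ...
def train_model(*args): ...


allowed = allowlist_names([split_dataset, train_model])
code = 'f = split_dataset(data, "y")\nm = train_model(f)\nopen("/etc/passwd")'
print(unknown_calls(code, allowed))  # ['open']
```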

Here's an example of what the LLM might generate for a single experiment:

```python
train_file = split_dataset(data, target_column, 0.2, time_column, "train")
test_file  = split_dataset(data, target_column, 0.2, time_column, "test")

eng_train = engineer_features(train_file, {
    "rolling_columns": ["vibration_mms", "temperature_c"],
    "windows": [6, 12, 24],
    "group_column": "machine_id",
    "time_column": "timestamp"
})
eng_test = engineer_features(test_file, {
    "rolling_columns": ["vibration_mms", "temperature_c"],
    "windows": [6, 12, 24],
    "group_column": "machine_id",
    "time_column": "timestamp"
})

model_file = train_model(eng_train, target_column, "xgboost", {
    "n_estimators": 200, "max_depth": 8, "scale_pos_weight": 33
})
eval_result = evaluate_model(model_file, eng_test, target_column)

{"experiment_name": experiment_name, "algorithm": "xgboost",
 "metrics": eval_result["metrics"],
 "confusion_matrix": eval_result["confusion_matrix"],
 "threshold_analysis": eval_result["threshold_analysis"],
 "n_samples": eval_result["n_samples"]}
```

Each function call in that snippet dispatches a separate Flyte task. The `split_dataset` calls run on the tool environment's compute (2 CPU, 4Gi memory). The `train_model` call trains an actual XGBoost model. The last expression (a dict literal) is returned as the sandbox result.
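Because the result is simply the value of the last expression, the agent relies on it having a predictable shape. The orchestration prompt asks for a dict containing at least `experiment_name`, `algorithm`, `metrics`, `confusion_matrix`, `threshold_analysis`, and `n_samples`. A minimal shape check might look like the following. `check_sandbox_result` is a hypothetical helper, the real agent may validate differently, and the values in the example are placeholders.

```python
REQUIRED_KEYS = (
    "experiment_name", "algorithm", "metrics",
    "confusion_matrix", "threshold_analysis", "n_samples",
)


def check_sandbox_result(result: object) -> list[str]:
    """Return a list of problems; an empty list means the result has the expected shape."""
    if not isinstance(result, dict):
        return [f"expected a dict as the last expression, got {type(result).__name__}"]
    problems = [f"missing key: {k}" for k in REQUIRED_KEYS if k not in result]
    if "metrics" in result and not isinstance(result["metrics"], dict):
        problems.append("metrics must be a dict")
    return problems


# Placeholder values, only the keys matter here.
good = {
    "experiment_name": "xgb_rolling", "algorithm": "xgboost",
    "metrics": {}, "confusion_matrix": [], "threshold_analysis": [], "n_samples": 0,
}
print(check_sandbox_result(good))  # []
print(check_sandbox_result({"experiment_name": "x", "algorithm": "xgboost", "metrics": {}}))
# ['missing key: confusion_matrix', 'missing key: threshold_analysis', 'missing key: n_samples']
```

Messages like these are also a natural thing to hand back to the LLM when asking it to repair its code.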

Sometimes the LLM generates code with bugs, like a wrong variable name or a missing argument. The agent handles this with a retry loop. If the sandbox raises an exception, the error message and the failing code are fed back to the LLM (`plan_experiment` accepts them as `previous_error` and `previous_code`), which gets a chance to fix the issue:
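Stripped to its essentials, such a loop might look like the sketch below. It assumes two async callables: `plan`, standing in for `plan_experiment`, and `execute`, standing in for the sandbox call. The name `run_with_repair`, the `max_attempts` default, and the fake planner are all illustrative assumptions, not the agent's actual code.

```python
import asyncio
from typing import Any, Awaitable, Callable


async def run_with_repair(
    plan: Callable[[str, str], Awaitable[str]],
    execute: Callable[[str], Awaitable[Any]],
    max_attempts: int = 3,
) -> Any:
    """Generate code, run it, and on failure re-plan with the error and the failing code."""
    previous_error, previous_code = "", ""
    for _ in range(max_attempts):
        code = await plan(previous_error, previous_code)
        try:
            return await execute(code)
        except Exception as exc:  # any sandbox failure is fed back to the planner
            previous_error = f"{type(exc).__name__}: {exc}"
            previous_code = code
    raise RuntimeError(f"no working code after {max_attempts} attempts: {previous_error}")


seen_errors = []


async def fake_plan(previous_error: str, previous_code: str) -> str:
    seen_errors.append(previous_error)
    # The first draft has a typo; once an error is reported, the "LLM" corrects it.
    return "resul = 1" if not previous_error else "result = 1"


async def fake_execute(code: str) -> str:
    if code.startswith("resul "):
        raise NameError("name 'resul' is not defined")
    return "ok"


outcome = asyncio.run(run_with_repair(fake_plan, fake_execute))
print(outcome)      # ok
print(seen_errors)  # ['', "NameError: name 'resul' is not defined"]
```

Capping the attempts keeps a persistently broken experiment from consuming the whole budget. The complete `agent.py` module follows.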

```python
"""MLE Agent — orchestrates ML experiments using Flyte's durable sandbox.

The agent:
  1. Profiles the dataset using a trusted tool (data never touches the LLM).
  2. Asks OpenAI to design a set of experiments (algorithms, hyperparams, feature strategy).
  3. For each experiment, generates Monty orchestration code and executes it via
     flyte.sandbox.orchestrate_local(), which dispatches the heavy compute as durable tasks.
  4. Analyzes results, iterates if needed.
  5. Produces a model card summarizing the winning model.

The Monty sandbox ensures the LLM-generated orchestration code is safe — it can only
call the pre-approved tool functions and has no access to imports, network, or filesystem.
"""

import asyncio
import inspect
import json
import os
import textwrap
from dataclasses import dataclass

import flyte
import flyte.sandbox
from flyte.io import File

from mle_bot.schemas import ExperimentConfig, InitialDesign, IterationDecision

from mle_bot.environments import agent_env
from mle_bot.tools.data import profile_dataset, split_dataset
from mle_bot.tools.evaluation import evaluate_model, rank_experiments
from mle_bot.tools.exploration import explore_dataset
from mle_bot.tools.features import engineer_features
from mle_bot.tools.predictions import get_predictions
from mle_bot.tools.resampling import resample_dataset
from mle_bot.tools.selection import select_features
from mle_bot.tools.training import train_model

# {{docs-fragment tools}}
# All tools exposed to the sandbox.
# Function names must match the names used in LLM-generated orchestration code.
TOOLS = [
    profile_dataset, split_dataset, explore_dataset,
    engineer_features, resample_dataset, select_features,
    train_model, get_predictions, evaluate_model, rank_experiments,
]
TOOLS_BY_NAME = {t.func.__name__ if hasattr(t, "func") else t.__name__: t for t in TOOLS}
# {{/docs-fragment tools}}

# ---------------------------------------------------------------------------
# Prompt builders
# ---------------------------------------------------------------------------

def _tool_signatures() -> str:
    """Build a summary of available tool signatures and docstrings for the system prompt."""
    parts = []
    for t in TOOLS:
        func = t.func if hasattr(t, "func") else t
        sig = inspect.signature(func)
        doc = inspect.getdoc(func) or ""
        # Trim docstring to first 40 lines so prompt stays manageable
        doc_lines = doc.splitlines()[:40]
        short_doc = "\n    ".join(doc_lines)
        parts.append(f"async def {func.__name__}{sig}:\n    \"\"\"{short_doc}\"\"\"\n    ...")
    return "\n\n".join(parts)

# {{docs-fragment orchestration_prompt}}
def _build_orchestration_system_prompt(profile: dict) -> str:
    monty_rules = flyte.sandbox.ORCHESTRATOR_SYNTAX_PROMPT
    tools_section = _tool_signatures()
    is_imbalanced = profile.get("is_imbalanced", False)
    class_balance = profile.get("class_balance", {})
    columns = profile.get("columns", [])
    numeric_cols = profile.get("numeric_columns", [])
    categorical_cols = profile.get("categorical_columns", [])
    corr = profile.get("feature_target_corr", {})
    corr_str = ", ".join(f"{k}: {v:+.3f}" for k, v in list(corr.items())[:8]) if corr else "n/a"
    shape = profile.get("shape", [0, 0])
    return f"""\
You are an expert ML engineer. Your job is to design and write the best possible
pipeline for a machine learning experiment, then generate the Python orchestration
code to execute it.

The code runs inside a restricted sandbox. The last expression in your code
is returned as the result. All tool calls are made like regular function calls —
you do NOT need to await them.

## Dataset context

Shape: {shape[0]:,} rows × {shape[1]} columns
Numeric features: {numeric_cols}
Categorical features (excluded from model — not supported): {categorical_cols}
Class balance: {class_balance}, imbalanced: {is_imbalanced}
Feature-target correlations (raw, point-biserial): {corr_str}

## General ML best practices — apply these based on the dataset context above

**Feature engineering** (engineer_features tool):
- Sequential/time-series data (timestamp column present, rows ordered over time):
  rolling window features (means, stds, min/max) capture trends that point-in-time
  readings miss. Lag features capture recent history. Choose window sizes relative
  to the prediction horizon and temporal resolution of the data.
- Tabular cross-sectional data: normalization helps linear models and distance-based
  methods. Interaction terms can help if correlations are weak individually.
- Consider skipping feature engineering entirely for a baseline — it establishes
  whether raw features already carry enough signal.

**Class imbalance** (when is_imbalanced=true):
- Tree ensembles: use class_weight="balanced" or scale_pos_weight=n_neg/n_pos.
- Threshold: the default 0.5 decision threshold may not be optimal — the model's
  probability output is what matters, threshold is tuned at deployment time.
- Metric: ROC-AUC is robust to imbalance; avg_precision is better when positives
  are very rare.

**Algorithm selection**:
- XGBoost / GradientBoosting: strong default for tabular data, handles missing
  values, built-in imbalance handling. Start here unless data is very small.
- RandomForest: more robust to outliers, good for noisy data, parallelizes well.
- LogisticRegression: fast linear baseline. Useful to establish whether the
  problem is linearly separable before adding complexity.
- Prefer simpler models when n_samples < 5,000 to avoid overfitting.

**Resampling** (resample_dataset tool) — data-level imbalance handling:
- Use when class_weight/scale_pos_weight alone isn't improving recall adequately,
  or when you want to test whether data-level vs algorithm-level imbalance handling
  works better for this dataset.
- ONLY resample the TRAIN split — never test. Resampling test data gives misleading metrics.
- "oversample": fast, no new information, good first try.
- "smote": synthetic samples via interpolation — more diverse than random oversampling,
  better for high-dimensional or sparse minority classes.
- "undersample": loses majority data — only useful when majority class is very large
  and training speed is a concern.

**Feature selection** (select_features tool) — prune after feature engineering:
- Use after engineer_features when the feature count is large (20+) and you suspect
  many features are redundant or noisy (e.g., rolling stats at many window sizes).
- "mutual_info": ranks by non-linear dependence with target — best general choice.
- "variance_threshold": drops near-constant features — cheap first pass.
- "correlation_filter": drops redundant features that are highly correlated with
  each other — useful when many rolling windows capture the same trend.
- Can be applied before or after splitting. Apply the same selection to both train
  and test to ensure the model sees the same features at evaluation time.

**Prediction output** (get_predictions tool) — enables two advanced patterns:
1. Error analysis: train a model → get_predictions(model, test_file, target) →
   explore_dataset(predictions_file, {{"class_distributions": ["feature_x"],
   "target_column": "correct"}}) to see which examples the model gets wrong.
   Use this to inform feature engineering for the next iteration.
2. Stacking: train base_model → get_predictions(base_model, train_file, target) →
   train a meta_model on the predictions CSV (use "predicted_prob" as a feature
   alongside original features) → evaluate meta_model on test.
   get_predictions returns a CSV with columns: all originals + predicted_prob,
   predicted_class, correct.

**Pipeline structure** — you are not required to follow a fixed sequence.
Design what makes sense for this specific experiment.

## Available tools

{tools_section}

## Monty sandbox restrictions

{monty_rules}

## Critical patterns for using tool results

split_dataset returns a File — call it twice:
    train_file = split_dataset(data, target_column, 0.2, time_column, "train")
    test_file  = split_dataset(data, target_column, 0.2, time_column, "test")

engineer_features returns a File — chain calls freely:
    eng = engineer_features(train_file, {{"rolling_columns": [...], "windows": [...]}})
    eng2 = engineer_features(eng, {{"normalize": True, "target_column": target_column}})

train_model returns a File — pass directly to evaluate_model:
    model_file = train_model(train_file, target_column, algorithm, hyperparams)
    eval_result = evaluate_model(model_file, test_file, target_column)

evaluate_model returns a dict — subscript reads are allowed:
    roc = eval_result["metrics"]["roc_auc"]

Do NOT use augmented assignment (+=), subscript assignment (d["k"]=v), or try/except.
Build dicts as literals only. The last expression (no assignment) is the return value.

## When fixing a previous error

Read the error and the failing code carefully before writing a fix. Identify the root
cause — do not just change variable names or add no-ops. Trace what each tool returns,
what each subsequent call expects, and where the mismatch is. Then fix the underlying
logic, not just the surface symptom.

## Pipeline design — you decide the structure

You are NOT required to follow a fixed sequence. Design the pipeline that makes most
sense for the experiment. Examples of valid approaches:

Baseline (no feature engineering):
    train_file = split_dataset(data, target_column, 0.2, time_column, "train")
    test_file = split_dataset(data, target_column, 0.2, time_column, "test")
    model_file = train_model(train_file, target_column, algorithm, hyperparams)
    eval_result = evaluate_model(model_file, test_file, target_column)
    {{"experiment_name": experiment_name, "algorithm": algorithm, "metrics": eval_result["metrics"], "confusion_matrix": eval_result["confusion_matrix"], "threshold_analysis": eval_result["threshold_analysis"], "n_samples": eval_result["n_samples"]}}

Two-stage feature engineering (rolling then normalize separately):
    train_file = split_dataset(data, target_column, 0.2, time_column, "train")
    test_file = split_dataset(data, target_column, 0.2, time_column, "test")
    rolled_train = engineer_features(train_file, {{"rolling_columns": ["vibration"], "windows": [6, 24]}})
    rolled_test  = engineer_features(test_file,  {{"rolling_columns": ["vibration"], "windows": [6, 24]}})
    eng_train = engineer_features(rolled_train, {{"normalize": True, "target_column": target_column}})
    eng_test  = engineer_features(rolled_test,  {{"normalize": True, "target_column": target_column}})
    model_file = train_model(eng_train, target_column, algorithm, hyperparams)
    eval_result = evaluate_model(model_file, eng_test, target_column)
    {{"experiment_name": experiment_name, "algorithm": algorithm, "metrics": eval_result["metrics"], "confusion_matrix": eval_result["confusion_matrix"], "threshold_analysis": eval_result["threshold_analysis"], "n_samples": eval_result["n_samples"]}}

Compare two class weightings and return the better model:
    train_file = split_dataset(data, target_column, 0.2, time_column, "train")
    test_file = split_dataset(data, target_column, 0.2, time_column, "test")
    model_a = train_model(train_file, target_column, "xgboost", {{"n_estimators": 100, "scale_pos_weight": 10}})
    model_b = train_model(train_file, target_column, "xgboost", {{"n_estimators": 100, "scale_pos_weight": 33}})
    eval_a = evaluate_model(model_a, test_file, target_column)
    eval_b = evaluate_model(model_b, test_file, target_column)
    best_eval = eval_a if eval_a["metrics"]["roc_auc"] > eval_b["metrics"]["roc_auc"] else eval_b
    {{"experiment_name": experiment_name, "algorithm": "xgboost", "metrics": best_eval["metrics"], "confusion_matrix": best_eval["confusion_matrix"], "threshold_analysis": best_eval["threshold_analysis"], "n_samples": best_eval["n_samples"]}}

SMOTE oversampling before training:
    train_file = split_dataset(data, target_column, 0.2, time_column, "train")
    test_file  = split_dataset(data, target_column, 0.2, time_column, "test")
    eng_train  = engineer_features(train_file, {{"rolling_columns": ["vibration_mms"], "windows": [6, 12]}})
    eng_test   = engineer_features(test_file,  {{"rolling_columns": ["vibration_mms"], "windows": [6, 12]}})
    resampled_train = resample_dataset(eng_train, target_column, {{"strategy": "smote", "target_ratio": 0.2}})
    model_file = train_model(resampled_train, target_column, algorithm, hyperparams)
    eval_result = evaluate_model(model_file, eng_test, target_column)
    {{"experiment_name": experiment_name, "algorithm": algorithm, "metrics": eval_result["metrics"], "confusion_matrix": eval_result["confusion_matrix"], "threshold_analysis": eval_result["threshold_analysis"], "n_samples": eval_result["n_samples"]}}

Feature engineering followed by feature selection:
    train_file = split_dataset(data, target_column, 0.2, time_column, "train")
    test_file  = split_dataset(data, target_column, 0.2, time_column, "test")
    eng_train  = engineer_features(train_file, {{"rolling_columns": ["vibration_mms", "temperature_c"], "windows": [6, 12, 24]}})
    eng_test   = engineer_features(test_file,  {{"rolling_columns": ["vibration_mms", "temperature_c"], "windows": [6, 12, 24]}})
    sel_train  = select_features(eng_train, target_column, {{"method": "mutual_info", "k": 15}})
    sel_test   = select_features(eng_test,  target_column, {{"method": "mutual_info", "k": 15}})
    model_file = train_model(sel_train, target_column, algorithm, hyperparams)
    eval_result = evaluate_model(model_file, sel_test, target_column)
    {{"experiment_name": experiment_name, "algorithm": algorithm, "metrics": eval_result["metrics"], "confusion_matrix": eval_result["confusion_matrix"], "threshold_analysis": eval_result["threshold_analysis"], "n_samples": eval_result["n_samples"]}}

Error analysis — explore what the model gets wrong, then return that as insight:
    train_file = split_dataset(data, target_column, 0.2, time_column, "train")
    test_file  = split_dataset(data, target_column, 0.2, time_column, "test")
    model_file = train_model(train_file, target_column, algorithm, hyperparams)
    pred_file  = get_predictions(model_file, test_file, target_column)
    error_analysis = explore_dataset(pred_file, {{"target_column": "correct", "class_distributions": ["vibration_mms", "temperature_c"]}})
    eval_result = evaluate_model(model_file, test_file, target_column)
    {{"experiment_name": experiment_name, "algorithm": algorithm, "metrics": eval_result["metrics"], "confusion_matrix": eval_result["confusion_matrix"], "threshold_analysis": eval_result["threshold_analysis"], "n_samples": eval_result["n_samples"], "error_analysis": error_analysis}}

The last expression MUST be a dict with at minimum these keys:
    experiment_name, algorithm, metrics, confusion_matrix, threshold_analysis, n_samples
Additional keys (e.g. error_analysis) are allowed and will appear in the report.

## Response format

Respond in exactly this format:

## Reasoning
[Your thinking: what pipeline makes sense for this experiment and why. Consider whether
feature engineering helps, whether class imbalance needs special treatment, whether
chaining multiple steps adds value, etc.]

## Code
```python
[your orchestration code]
```
"""
# {{/docs-fragment orchestration_prompt}}

def _build_analysis_system_prompt(max_iterations: int, current_iteration: int) -> str:
    remaining = max_iterations - current_iteration - 1
    return f"""\
You are an expert ML engineer analyzing experiment results to guide the next iteration
of model development.

You must respond with valid JSON only — no markdown, no explanation outside the JSON.

Response format:
{{
  "should_continue": true | false,
  "reasoning": "What you observed, what it tells you, and what to try next",
  "exploration_requests": [
    {{
      "question": "The specific hypothesis you are testing, e.g. 'Do failure cases show meaningfully higher vibration than healthy cases?'",
      "analysis_type": "class_distributions",
      "target_column": "failure_24h",
      "class_distributions": ["vibration_mms", "temperature_c"]
    }}
  ],
  "next_experiments": [
    {{
      "name": "descriptive experiment name",
      "algorithm": "xgboost" | "random_forest" | "gradient_boosting" | "logistic_regression",
      "hyperparams": {{ ... algorithm-specific hyperparams ... }},
      "feature_config": {{
        "group_column": "...",
        "time_column": "...",
        "rolling_columns": [...],
        "windows": [...],
        "lag_columns": [...],
        "lags": [...],
        "normalize": true | false,
        "drop_columns": [...],
        "fillna_method": "forward"
      }},
      "rationale": "Why this experiment is worth trying"
    }}
  ]
}}

exploration_requests rules:
- Max 2 requests per iteration.
- Each request targets EXACTLY ONE analysis_type. Do not mix multiple types in one request.
- Supported analysis_type values and their required config fields:
    "class_distributions" → requires: target_column, class_distributions (list of columns)
    "correlation_matrix"  → requires: correlation_matrix: true
    "temporal_trend"      → requires: temporal_trend: {{time_column, target_column, freq}}
    "group_stats"         → requires: group_stats: {{group_column, target_column}}
    "outlier_summary"     → requires: outlier_summary (list of columns)
    "feature_target_corr_by_group" → requires: feature_target_corr_by_group: {{group_column, target_column, feature_columns}}
- The "question" field is required. It must be a specific testable hypothesis, not a
  general request. Bad: "explore the data". Good: "Is vibration_mms higher for failures?"
- Set exploration_requests to [] if the current results already tell you enough to
  design the next experiments. Only explore when you have a concrete unanswered question.

When deciding next experiments, reason about WHAT WAS TRIED vs what hasn't been explored.
Each result includes used_feature_engineering, used_rolling_features, used_lag_features.
Think systematically: if no feature engineering was tried yet, does the data profile
suggest it would help (weak raw correlations, temporal/sequential structure)?
If feature engineering helped, can it be improved? Avoid experiments identical to ones tried.

Iteration context: this is iteration {current_iteration + 1} of {max_iterations} requested.
Remaining iterations allowed: {remaining}.
Set should_continue=false only if:
- Best ROC-AUC >= 0.97, OR
- No remaining iterations (remaining == 0), OR
- Results have genuinely plateaued (< 0.005 ROC-AUC improvement over last iteration
  AND you have already tried the most promising directions)
Otherwise keep exploring — the user asked for {max_iterations} iterations for a reason.
"""

def _build_initial_design_system_prompt() -> str:
    return """\
You are an expert ML engineer. Given a dataset profile and a problem description,
design the first batch of experiments to run.

You must respond with valid JSON only — no markdown, no explanation outside the JSON.

Response format:
{
  "problem_type": "binary_classification",
  "primary_metric": "roc_auc" | "f1" | "recall",
  "reasoning": "Brief description of your strategy",
  "experiments": [
    {
      "name": "descriptive experiment name",
      "algorithm": "xgboost" | "random_forest" | "gradient_boosting" | "logistic_regression",
      "hyperparams": { ... algorithm-specific hyperparams ... },
      "feature_config": {
        "group_column": "",
        "time_column": "",
        "rolling_columns": [],
        "windows": [],
        "lag_columns": [],
        "lags": [],
        "normalize": false,
        "drop_columns": [],
        "fillna_method": "forward"
      },
      "rationale": "Why this experiment makes sense given the data profile"
    }
  ]
}

Design 2-3 experiments for the first batch. Good first batches typically include:
- A fast baseline to establish a floor (e.g. logistic_regression with default settings)
- Your best initial hypothesis given the data profile
- Optionally one variant that tests a specific idea suggested by the profile

Use the dataset profile to guide your choices:
- feature_target_corr: weak raw correlations suggest feature engineering may help
- categorical_columns: note these are excluded from the model automatically
- is_imbalanced: handle with class_weight or scale_pos_weight
- Shape and column types should inform algorithm complexity (simpler models for small datasets)
- A time column suggests sequential structure; lag/rolling features may capture temporal patterns

The feature_config in each experiment describes what engineer_features should apply.
Leave all fields empty/false if no feature engineering is needed for that experiment.
The orchestration code generator will decide the exact pipeline — your job here is
to specify what the experiment is trying to learn, not to prescribe every implementation detail.
"""

# ---------------------------------------------------------------------------
# LLM client
# ---------------------------------------------------------------------------

def _openai_client():
    from openai import OpenAI
    return OpenAI(api_key=os.environ["OPENAI_API_KEY"])

async def _call_llm(system: str, messages: list[dict], model: str = "gpt-4o") -> str:
    """Call OpenAI and return the response text."""
    client = _openai_client()
    response = await asyncio.to_thread(
        client.chat.completions.create,
        model=model,
        messages=[{"role": "system", "content": system}, *messages],
        temperature=0.2,
    )
    return response.choices[0].message.content

def _extract_code(text: str) -> str:
    """Pull Python code out of a markdown code block."""
    if "```python" in text:
        start = text.index("```python") + len("```python")
        end = text.index("```", start)
        return text[start:end].strip()
    if "```" in text:
        start = text.index("```") + 3
        end = text.index("```", start)
        return text[start:end].strip()
    return text.strip()

def _extract_reasoning(text: str) -> str:
    """Extract the ## Reasoning section from LLM response."""
    if "## Reasoning" in text:
        start = text.index("## Reasoning") + len("## Reasoning")
        if "## Code" in text:
            end = text.index("## Code")
            return text[start:end].strip()
        return text[start:].strip()
    return ""

def _parse_json(text: str) -> dict:
    """Extract and parse JSON from LLM response."""
    text = text.strip()
    if "```json" in text:
        start = text.index("```json") + 7
        end = text.index("```", start)
        text = text[start:end].strip()
    elif "```" in text:
        start = text.index("```") + 3
        end = text.index("```", start)
        text = text[start:end].strip()
    return json.loads(text)

# ---------------------------------------------------------------------------
# Display helpers
# ---------------------------------------------------------------------------

def _recommend_threshold(threshold_analysis: list, min_precision: float = 0.70) -> dict | None:
    """Find the threshold that maximises recall subject to precision >= min_precision."""
    candidates = [t for t in threshold_analysis if t["precision"] >= min_precision]
    if not candidates:
        return None
    return max(candidates, key=lambda t: t["recall"])

def _print_experiment_table(results: list["ExperimentResult"], best_name: str) -> None:
    """Print a ranked comparison table of all experiments."""
    sorted_results = sorted(results, key=lambda r: r.metrics.get("roc_auc", 0), reverse=True)
    print("\n" + "─" * 78)
    print(f"  {'Rank':<5} {'Experiment':<32} {'ROC-AUC':<9} {'F1':<7} {'Recall':<8} {'Note'}")
    print("─" * 78)
    for rank, r in enumerate(sorted_results, 1):
        note = "◀ winner" if r.name == best_name else ""
        roc = r.metrics.get("roc_auc", 0)
        f1 = r.metrics.get("f1", 0)
        recall = r.metrics.get("recall", 0)
        print(f"  {rank:<5} {r.name:<32} {roc:<9.4f} {f1:<7.4f} {recall:<8.4f} {note}")
    print("─" * 78)

def _print_threshold_recommendation(threshold_analysis: list, default_metrics: dict) -> None:
    """Print the operational threshold recommendation."""
    rec = _recommend_threshold(threshold_analysis)
    if not rec:
        return
    default_recall = default_metrics.get("recall", 0)
    default_precision = default_metrics.get("precision", 0)
    missed_pct = round((1 - rec["recall"]) * 100, 1)
    false_alarm_pct = round((1 - rec["precision"]) * 100, 1)

    print(f"\n  Recommended decision threshold: {rec['threshold']}")
    print(f"  ├─ Precision : {rec['precision']:.0%}   ({false_alarm_pct}% of alerts are false alarms)")
    print(f"  ├─ Recall    : {rec['recall']:.0%}   (catches {rec['recall']*100:.0f}% of actual failures)")
    print(f"  └─ F1        : {rec['f1']:.4f}")
    print(f"  Default threshold (0.5): Precision={default_precision:.0%}, Recall={default_recall:.0%}")
    if rec["recall"] > default_recall:
        extra = round((rec["recall"] - default_recall) * 100, 1)
        print(f"  → Lowering threshold catches {extra}% more failures at cost of more alerts")

# ---------------------------------------------------------------------------
# Orchestration code generation (durable Flyte task with Flyte report)
# ---------------------------------------------------------------------------

@agent_env.task
async def plan_experiment(
    experiment_json: str,
    profile_json: str,
    target_column: str,
    time_column: str,
    previous_error: str = "",
    previous_code: str = "",
    llm_model: str = "gpt-4o",
) -> str:
    """LLM plans a single experiment: reasons about the pipeline and generates Monty code.

    Runs as a durable Flyte task so each experiment's planning step is traceable.
    Returns a JSON string: {"code": "...", "reasoning": "..."}.

    Args:
        experiment_json: JSON string of the experiment spec (name, algorithm, hyperparams, ...).
        profile_json: JSON string of the dataset profile from profile_dataset.
        target_column: Name of the target column.
        time_column: Time column for temporal splitting, or empty string.
        previous_error: Error message from the previous attempt (empty on first try).
        previous_code: Code that failed on the previous attempt (empty on first try).
        llm_model: OpenAI model identifier.

    Returns:
        str — JSON string with keys "code" and "reasoning".
    """
    experiment = json.loads(experiment_json)
    profile = json.loads(profile_json)
    exp_name = experiment.get("name", "experiment")

    # Strip rationale — it was written by the design LLM to explain *why* this
    # experiment was chosen. Passing it here causes plan_experiment to parrot it
    # back as "reasoning" instead of independently thinking about *how* to build
    # the best pipeline. Keep only the technical spec.
    pipeline_spec = {
        k: v for k, v in experiment.items()
        if k not in ("rationale",)
    }

    system = _build_orchestration_system_prompt(profile)

    user_content = textwrap.dedent(f"""
        Design and implement the best pipeline for this experiment:

        Name: {exp_name}
        Algorithm: {pipeline_spec.get("algorithm")}
        Hyperparams: {json.dumps(pipeline_spec.get("hyperparams", {}), indent=2)}
        Feature config hint: {json.dumps(pipeline_spec.get("feature_config", {}), indent=2)}

        Available sandbox inputs:
        - data: File  — the full dataset CSV
        - target_column: str = "{target_column}"
        - time_column: str = "{time_column}"  (empty string means no time ordering)
        - experiment_name: str = "{exp_name}"

        The feature config hint is a suggestion from the experiment designer — you can
        follow it, improve on it, or override it if the dataset context and your ML
        judgment suggest a better approach. In your ## Reasoning, explain your actual
        pipeline decisions: what you chose to do (or not do) and why, based on the
        dataset profile above. Do not restate the experiment name or why it was chosen.
    """).strip()

    messages = [{"role": "user", "content": user_content}]
    if previous_code and previous_error:
        messages = [
            {"role": "user", "content": user_content},
            {"role": "assistant", "content": f"```python\n{previous_code}\n```"},
            {"role": "user", "content": f"That code failed with this error:\n\n{previous_error}\n\nPlease fix it."},
        ]

    response = await _call_llm(system, messages, llm_model)
    reasoning = _extract_reasoning(response)
    code = _extract_code(response)
    return json.dumps({"code": code, "reasoning": reasoning})

@flyte.trace
async def design_experiments(
    problem_description: str,
    profile_json: str,
    llm_model: str = "gpt-4o",
) -> str:
    """LLM designs the initial batch of experiments given problem + dataset profile.

    Traced so the prompt/response is visible in the Flyte UI and results are
    cached for deterministic replay on crash/retry.
    Returns raw LLM response (JSON string matching InitialDesign schema).
    """
    design_prompt = textwrap.dedent(f"""
        Problem description: {problem_description}

        Dataset profile:
        {profile_json}

        Design the first batch of experiments.
    """).strip()
    return await _call_llm(
        _build_initial_design_system_prompt(),
        [{"role": "user", "content": design_prompt}],
        llm_model,
    )

@flyte.trace
async def analyze_iteration(
    analysis_prompt: str,
    max_iterations: int,
    current_iteration: int,
    llm_model: str = "gpt-4o",
) -> str:
    """LLM analyzes experiment results and decides whether/how to continue.

    Traced so the prompt/response is visible in the Flyte UI and results are
    cached for deterministic replay on crash/retry.
    Returns raw LLM response (JSON string matching IterationDecision schema).
    """
    return await _call_llm(
        _build_analysis_system_prompt(max_iterations, current_iteration),
        [{"role": "user", "content": analysis_prompt}],
        llm_model,
    )

@flyte.trace
async def plan_followup(
    analysis_prompt: str,
    analysis_response: str,
    followup_prompt: str,
    max_iterations: int,
    current_iteration: int,
    llm_model: str = "gpt-4o",
) -> str:
    """LLM designs next experiments after targeted data explorations.

    Traced so the prompt/response is visible in the Flyte UI and results are
    cached for deterministic replay on crash/retry.
    Returns raw LLM response (JSON string with {"next_experiments": [...]}).
    """
    return await _call_llm(
        _build_analysis_system_prompt(max_iterations, current_iteration),
        [
            {"role": "user", "content": analysis_prompt},
            {"role": "assistant", "content": analysis_response},
            {"role": "user", "content": followup_prompt},
        ],
        llm_model,
    )

def _corrupt_experiment_for_demo(exp_dict: dict) -> dict:
    """Introduce a deliberate error into the first experiment for demo purposes.

    Corrupts the algorithm name so the LLM must recover from a known-bad value.
    The retry loop will catch this, regenerate with the error message, and fix it.
    """
    corrupted = dict(exp_dict)
    corrupted["algorithm"] = corrupted["algorithm"] + "_INVALID"
    return corrupted

# ---------------------------------------------------------------------------
# Main agent loop
# ---------------------------------------------------------------------------

@dataclass
class ExperimentResult:
    name: str
    algorithm: str
    metrics: dict
    confusion_matrix: dict
    threshold_analysis: list
    n_samples: int
    code: str
    attempts: int
    reasoning: str = ""
    error: str = ""

@dataclass
class AgentResult:
    model_card: str
    best_experiment: str
    best_metrics: dict
    all_results: list[ExperimentResult]
    iterations: int
    total_experiments: int

async def _run_experiment(
    exp: "ExperimentConfig",
    exp_dict: dict,
    inject_failure: bool,
    data: File,
    target_column: str,
    time_column: str,
    profile: dict,
    llm_model: str,
    max_retries: int,
) -> "ExperimentResult | None":
    """Run a single experiment with retries. Returns None on total failure."""
    exp_name = exp.name
    profile_json = json.dumps(profile)
    print(f"\n   ┌─ {exp_name}  [{exp.algorithm}]")
    if exp.rationale:
        for line in textwrap.wrap(exp.rationale, width=58):
            print(f"   │  {line}")
    if inject_failure:
        print(f"   │  [injecting failure for demo: algorithm='{exp_dict['algorithm']}']")

    code = ""
    error = ""
    result = None
    attempt = 0

    reasoning = ""
    # {{docs-fragment retry_loop}}
    for attempt in range(max_retries):
        try:
            with flyte.group(exp_name):
                plan_json = await plan_experiment.aio(
                    experiment_json=json.dumps(exp_dict),
                    profile_json=profile_json,
                    target_column=target_column,
                    time_column=time_column,
                    previous_error=error,
                    previous_code=code,
                    llm_model=llm_model,
                )
                plan = json.loads(plan_json)
                code = plan["code"]
                reasoning = plan.get("reasoning", "")
                result = await flyte.sandbox.orchestrate_local(
                    code,
                    inputs={"data": data, "target_column": target_column,
                            "time_column": time_column, "experiment_name": exp_name},
                    tasks=TOOLS,
                )
            error = ""
            break
        except Exception as exc:
            error = str(exc)
            short_error = error[:100] + "..." if len(error) > 100 else error
            print(f"   │  attempt {attempt + 1} failed: {short_error}")
            print(f"   │  → asking LLM to fix and retry...")
            if inject_failure and attempt == 0:
                exp_dict = exp.model_dump()
    # {{/docs-fragment retry_loop}}

    if result and not error:
        exp_result = ExperimentResult(
            name=exp_name,
            algorithm=exp.algorithm,
            metrics=result.get("metrics", {}),
            confusion_matrix=result.get("confusion_matrix", {}),
            threshold_analysis=result.get("threshold_analysis", []),
            n_samples=result.get("n_samples", 0),
            code=code,
            reasoning=reasoning,
            attempts=attempt + 1,
        )
        m = exp_result.metrics
        attempts_note = f" (recovered after {attempt + 1} attempts)" if attempt > 0 else ""
        print(f"   └─ ROC-AUC={m.get('roc_auc')}, F1={m.get('f1')}, Recall={m.get('recall')}{attempts_note}")
        return exp_result

    print(f"   └─ FAILED after {max_retries} attempts — skipping.")
    return None

async def run_agent(
    data: File,
    problem_description: str,
    target_column: str,
    time_column: str = "",
    max_iterations: int = 3,
    max_retries_per_experiment: int = 3,
    llm_model: str = "gpt-4o",
    inject_failure: bool = False,
) -> AgentResult:
    """Run the MLE agent end-to-end.

    Args:
        data: CSV file containing the dataset.
        problem_description: Natural language description of the ML problem.
        target_column: Name of the target column to predict.
        time_column: Optional column to use for time-based train/test split.
        max_iterations: Maximum number of experiment iterations to run.
        max_retries_per_experiment: Max times to retry a failed sandbox execution.
        llm_model: OpenAI model to use (default: gpt-4o).
        inject_failure: If True, corrupts the first experiment to demonstrate self-healing.
    """
    print(f"\n{'='*60}")
    print(f"MLE Agent starting")
    print(f"Problem: {problem_description}")
    print(f"Target: {target_column}")
    if inject_failure:
        print(f"[demo mode: failure injection enabled]")
    print(f"{'='*60}\n")

    # {{docs-fragment phase1_profile}}
    # --- Phase 1: Profile the dataset (trusted tool, LLM never sees raw data) ---
    print(">> Phase 1: Profiling dataset...")
    with flyte.group("profile"):
        profile = await profile_dataset(data, target_column)
    # {{/docs-fragment phase1_profile}}
    print(f"   Shape: {profile['shape']}, Classes: {profile['target_distribution']}")
    print(f"   Imbalanced: {profile['is_imbalanced']}, Columns: {len(profile['columns'])}")
    corr = profile.get("feature_target_corr", {})
    top_corr = list(corr.items())[:5]
    print(f"   Top correlations: {', '.join(f'{k}={v:+.3f}' for k,v in top_corr)}")

    # Stream report: dataset summary
    await flyte.report.log.aio(
        f"<h1>MLE Agent Run</h1>"
        f"<p><b>Problem:</b> {problem_description}</p>"
        f"<p><b>Dataset:</b> {profile['shape'][0]:,} rows × {profile['shape'][1]} cols &nbsp;|&nbsp; "
        f"Class balance: {profile['class_balance']} &nbsp;|&nbsp; Imbalanced: {profile['is_imbalanced']}</p>"
        f"<p><b>Top feature-target correlations (raw):</b> "
        + ", ".join(f"{k}: {v:+.3f}" for k, v in top_corr) +
        f"</p><hr>",
        do_flush=True,
    )

    # --- Phase 2: LLM designs initial experiments ---
    print("\n>> Phase 2: Designing initial experiments...")
    design_response = await design_experiments(
        problem_description=problem_description,
        profile_json=json.dumps(profile),
        llm_model=llm_model,
    )
    design = InitialDesign.model_validate(_parse_json(design_response))
    print(f"   Primary metric: {design.primary_metric}")
    print(f"   Strategy: {design.reasoning}")
    print(f"   Experiments planned: {len(design.experiments)}")

    all_results: list[ExperimentResult] = []
    iteration_log: list[dict] = []  # tracks per-iteration decisions + explorations for summary
    current_experiments: list[ExperimentConfig] = design.experiments
    first_experiment = True

    # --- Phase 3: Iterative experiment loop ---
    for iteration in range(max_iterations):
        experiments = current_experiments

        if not experiments:
            print(f"\n>> No experiments to run in iteration {iteration + 1}. Stopping.")
            break

        print(f"\n>> Phase 3.{iteration + 1}: Running {len(experiments)} experiment(s) in parallel...")

        # Assign names and prepare dicts before launching in parallel
        exp_batch = []
        for i, exp in enumerate(experiments):
            if not exp.name:
                exp.name = f"experiment_{len(all_results) + i + 1}"
            exp_dict = exp.model_dump()
            inject_this = inject_failure and first_experiment and i == 0
            if inject_this:
                exp_dict = _corrupt_experiment_for_demo(exp_dict)
            first_experiment = False
            exp_batch.append((exp, exp_dict, inject_this))

        # {{docs-fragment parallel_execute}}
        batch_results = await asyncio.gather(*[
            _run_experiment(
                exp=exp,
                exp_dict=exp_dict,
                inject_failure=inject_this,
                data=data,
                target_column=target_column,
                time_column=time_column,
                profile=profile,
                llm_model=llm_model,
                max_retries=max_retries_per_experiment,
            )
            for exp, exp_dict, inject_this in exp_batch
        ])
        # {{/docs-fragment parallel_execute}}

        for exp_result in batch_results:
            if exp_result is not None:
                all_results.append(exp_result)
                # Stream report: each experiment as it completes
                m = exp_result.metrics
                html = (
                    f"<h3>Iteration {iteration + 1} — {exp_result.name}</h3>"
                    f"<p><b>Algorithm:</b> {exp_result.algorithm} &nbsp;|&nbsp; "
                    f"<b>ROC-AUC:</b> {m.get('roc_auc')} &nbsp;|&nbsp; "
                    f"<b>F1:</b> {m.get('f1')} &nbsp;|&nbsp; "
                    f"<b>Recall:</b> {m.get('recall')} &nbsp;|&nbsp; "
                    f"<b>Attempts:</b> {exp_result.attempts}</p>"
                )
                if exp_result.reasoning:
                    html += f"<details><summary>Reasoning</summary><pre>{exp_result.reasoning}</pre></details>"
                html += f"<details><summary>Generated Code</summary><pre>{exp_result.code}</pre></details>"
                await flyte.report.log.aio(html, do_flush=True)

        # --- Phase 4: Analyze results, decide whether to iterate ---
        if all_results and iteration < max_iterations - 1:
            print(f"\n>> Phase 4.{iteration + 1}: Analyzing results, deciding next steps...")
            results_summary = [
                {
                    "experiment_name": r.name,
                    "algorithm": r.algorithm,
                    "metrics": r.metrics,
                    "confusion_matrix": r.confusion_matrix,
                    "used_feature_engineering": "engineer_features" in r.code,
                    "used_rolling_features": "rolling_columns" in r.code,
                    "used_lag_features": "lag_columns" in r.code,
                }
                for r in all_results
            ]
            analysis_prompt = textwrap.dedent(f"""
                Problem: {problem_description}
                Dataset profile: shape={profile['shape']}, imbalanced={profile['is_imbalanced']}
                Feature-target correlations (raw): {json.dumps(profile.get('feature_target_corr', {}), indent=2)}

                Experiment results so far (iteration {iteration + 1}):
                {json.dumps(results_summary, indent=2)}

                Should we run more experiments? If yes, request any data explorations
                you need, then specify what experiments to run next.
            """).strip()

            analysis_response = await analyze_iteration(
                analysis_prompt=analysis_prompt,
                max_iterations=max_iterations,
                current_iteration=iteration,
                llm_model=llm_model,
            )
            decision = IterationDecision.model_validate(_parse_json(analysis_response))
            verdict = "continuing" if decision.should_continue else "stopping"
            print(f"   Decision: {verdict}")
            print(f"   Reasoning: {decision.reasoning}")

            # Stream report: analysis decision
            await flyte.report.log.aio(
                f"<h3>Analysis — Iteration {iteration + 1}</h3>"
                f"<p><b>Decision:</b> {verdict}</p>"
                f"<p><b>Reasoning:</b> {decision.reasoning}</p>",
                do_flush=True,
            )

            # Track this iteration for the experiment journey summary
            iter_entry = {
                "iteration": iteration + 1,
                "experiments": [r.name for r in batch_results if r is not None],
                "best_roc_auc": max(
                    (r.metrics.get("roc_auc", 0) for r in all_results), default=0
                ),
                "reasoning": decision.reasoning,
                "explorations": [],
            }

            # --- Targeted exploration before next iteration ---
            if decision.should_continue and decision.exploration_requests:
                print(f"   Running {len(decision.exploration_requests)} exploration request(s)...")
                exploration_questions = []
                exploration_results = []

                for i, req in enumerate(decision.exploration_requests):
                    question = req.get("question", f"Exploration {i + 1}")
                    # Strip agent-level metadata — tool only needs the analysis config
                    tool_config = {k: v for k, v in req.items() if k not in ("question", "analysis_type")}

                    print(f"   Q: {question}")
                    with flyte.group(f"explore_{iteration + 1}_{i + 1}"):
                        result = await explore_dataset(data, tool_config)
                    exploration_questions.append(question)
                    exploration_results.append(result)
                    iter_entry["explorations"].append({"question": question})

                    await flyte.report.log.aio(
                        f"<h4>Exploration {i + 1}</h4>"
                        f"<p><b>Question:</b> {question}</p>"
                        f"<details><summary>Results</summary><pre>{json.dumps(result, indent=2)}</pre></details>",
                        do_flush=True,
                    )

                # Build follow-up that explicitly connects each question to its answer
                qa_pairs = "\n\n".join(
                    f'Question {i + 1}: "{q}"\nResult:\n{json.dumps(r, indent=2)}'
                    for i, (q, r) in enumerate(zip(exploration_questions, exploration_results))
                )
                followup_prompt = textwrap.dedent(f"""
                    You requested {len(exploration_results)} targeted exploration(s).
                    Here is what you asked and what you learned:

                    {qa_pairs}

                    Given what you learned and your earlier reasoning:
                    "{decision.reasoning}"

                    Now specify the next experiments. For each experiment, briefly state
                    which exploration insight informed your choice.
                    Respond with valid JSON: {{"next_experiments": [...same schema as before...]}}
                """).strip()
                followup_response = await plan_followup(
                    analysis_prompt=analysis_prompt,
                    analysis_response=analysis_response,
                    followup_prompt=followup_prompt,
                    max_iterations=max_iterations,
                    current_iteration=iteration,
                    llm_model=llm_model,
                )
                followup = _parse_json(followup_response)
                current_experiments = IterationDecision.model_validate({
                    "should_continue": True,
                    "reasoning": decision.reasoning,
                    "next_experiments": followup.get("next_experiments", []),
                }).next_experiments
                print(f"   Post-exploration: {len(current_experiments)} experiment(s) planned")
            else:
                current_experiments = decision.next_experiments

            iteration_log.append(iter_entry)

            if not decision.should_continue:
                break

    # --- Phase 5: Rank all results and generate model card ---
    print(f"\n>> Phase 5: Ranking {len(all_results)} experiment(s) and generating model card...")

    if not all_results:
        return AgentResult(
            model_card="No experiments completed successfully.",
            best_experiment="",
            best_metrics={},
            all_results=[],
            iterations=iteration + 1,
            total_experiments=0,
        )

    ranking_input = [
        {
            "experiment_name": r.name,
            "metrics": r.metrics,
            "confusion_matrix": r.confusion_matrix,
        }
        for r in all_results
    ]
    with flyte.group("rank"):
        ranking = await rank_experiments(json.dumps(ranking_input))
    best_name = ranking["best_experiment"]
    best_result = next(r for r in all_results if r.name == best_name)

    _print_experiment_table(all_results, best_name)
    _print_threshold_recommendation(best_result.threshold_analysis, best_result.metrics)

    # Stream report: final rankings table
    rows = "".join(
        f"<tr><td>{row['rank']}</td>"
        f"<td>{'<b>' if row['experiment_name'] == best_name else ''}"
        f"{row['experiment_name']}"
        f"{'</b>' if row['experiment_name'] == best_name else ''}</td>"
        f"<td>{row['roc_auc']}</td><td>{row['f1']}</td>"
        f"<td>{row['recall']}</td><td>{row['precision']}</td></tr>"
        for row in ranking.get("ranking", [])
    )
    await flyte.report.log.aio(
        f"<hr><h2>Final Rankings</h2>"
        f"<table border='1' cellpadding='6' cellspacing='0'>"
        f"<tr><th>Rank</th><th>Experiment</th><th>ROC-AUC</th><th>F1</th><th>Recall</th><th>Precision</th></tr>"
        f"{rows}</table>"
        f"<p>{ranking.get('summary', '')}</p>",
        do_flush=True,
    )

    # Stream report: experiment journey summary
    journey_rows = ""
    for entry in iteration_log:
        exps = ", ".join(entry["experiments"]) if entry["experiments"] else "—"
        explorations = "; ".join(e["question"] for e in entry["explorations"]) if entry["explorations"] else "—"
        short_reasoning = (entry["reasoning"][:120] + "…") if len(entry["reasoning"]) > 120 else entry["reasoning"]
        journey_rows += (
            f"<tr>"
            f"<td style='text-align:center'>{entry['iteration']}</td>"
            f"<td>{exps}</td>"
            f"<td style='text-align:center'>{entry['best_roc_auc']:.4f}</td>"
            f"<td>{short_reasoning}</td>"
            f"<td>{explorations}</td>"
            f"</tr>"
        )
    await flyte.report.log.aio(
        f"<hr><h2>Experiment Journey</h2>"
        f"<table border='1' cellpadding='6' cellspacing='0' style='width:100%;border-collapse:collapse'>"
        f"<tr><th>Iter</th><th>Experiments</th><th>Best ROC-AUC</th><th>Key insight</th><th>Explorations</th></tr>"
        f"{journey_rows}"
        f"</table>",
        do_flush=True,
    )

    model_card = await _generate_model_card(
        problem_description=problem_description,
        profile=profile,
        all_results=all_results,
        best_result=best_result,
        ranking=ranking,
        iteration_log=iteration_log,
        llm_model=llm_model,
    )

    print(f"\n{'='*60}")
    print(f"DONE — Best model: {best_name}")
    print(f"       ROC-AUC={best_result.metrics.get('roc_auc')}, F1={best_result.metrics.get('f1')}")
    print(f"{'='*60}\n")

    return AgentResult(
        model_card=model_card,
        best_experiment=best_name,
        best_metrics=best_result.metrics,
        all_results=all_results,
        iterations=iteration + 1,
        total_experiments=len(all_results),
    )

async def _generate_model_card(
    problem_description: str,
    profile: dict,
    all_results: list[ExperimentResult],
    best_result: ExperimentResult,
    ranking: dict,
    iteration_log: list[dict],
    llm_model: str,
) -> str:
    """Generate a markdown model card summarizing the winning model."""
    system = textwrap.dedent("""
        You are an ML engineer writing a model card for a trained model.
        Write in markdown. Be concise but informative. Include:
        - Problem statement
        - Dataset summary
        - Experiment journey (brief per-iteration narrative: what was tried, what was learned, what changed)
        - Experiment summary (table of all experiments with metrics)
        - Winning model details (algorithm, key hyperparams, metrics, threshold analysis)
        - Recommendations for deployment (decision threshold, monitoring)
    """).strip()

    results_text = "\n".join(
        f"- {r.name} ({r.algorithm}): ROC-AUC={r.metrics.get('roc_auc')}, "
        f"F1={r.metrics.get('f1')}, Recall={r.metrics.get('recall')}"
        for r in all_results
    )

    journey_text = ""
    if iteration_log:
        journey_text = "\n\nIteration log:\n" + "\n".join(
            f"  Iteration {e['iteration']}: ran [{', '.join(e['experiments'])}], "
            f"best ROC-AUC so far={e['best_roc_auc']:.4f}. "
            f"Key insight: {e['reasoning'][:200]}. "
            + (f"Explorations: {'; '.join(x['question'] for x in e['explorations'])}" if e['explorations'] else "")
            for e in iteration_log
        )

    user_content = textwrap.dedent(f"""
        Problem: {problem_description}

        Dataset: {profile['shape'][0]} rows × {profile['shape'][1]} cols.
        Class balance: {profile['class_balance']}
        Imbalanced: {profile['is_imbalanced']}
        {journey_text}

        All experiments:
        {results_text}

        Best model: {best_result.name} ({best_result.algorithm})
        Metrics: {json.dumps(best_result.metrics, indent=2)}
        Confusion matrix: {json.dumps(best_result.confusion_matrix, indent=2)}
        Threshold analysis: {json.dumps(best_result.threshold_analysis, indent=2)}

        Ranking summary: {ranking['summary']}
    """).strip()

    response = await _call_llm(system, [{"role": "user", "content": user_content}], llm_model)
    return response

# ---------------------------------------------------------------------------
# Durable entrypoint (runs the agent as a Flyte task in the cloud)
# ---------------------------------------------------------------------------

# {{docs-fragment entrypoint}}
@agent_env.task(retries=1, report=True)
async def mle_agent_task(
    data: File,
    problem_description: str,
    target_column: str,
    time_column: str = "",
    max_iterations: int = 3,
) -> str:
    """Durable Flyte task entrypoint for the MLE agent."""
    result = await run_agent(
        data=data,
        problem_description=problem_description,
        target_column=target_column,
        time_column=time_column,
        max_iterations=max_iterations,
    )
    return result.model_card
# {{/docs-fragment entrypoint}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/mle_bot/mle_bot/agent.py*

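On the first attempt, `previous_error` and `previous_code` are empty. On later attempts, `plan_experiment` replays the conversation with the failing code as the assistant's turn and the error message as a new user turn. The LLM therefore sees exactly what went wrong and can fix it. In practice, most experiments succeed on the first try, with occasional recoveries on the second.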
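The retry-with-feedback pattern does not depend on Flyte. Below is a minimal, framework-free sketch. It assumes a hypothetical `generate(previous_code, previous_error)` callable standing in for `plan_experiment`, and a hypothetical `execute(code)` callable standing in for `flyte.sandbox.orchestrate_local`. Neither name is part of the tutorial's code.

```python
from typing import Any, Callable


def run_with_feedback(
    generate: Callable[[str, str], str],
    execute: Callable[[str], Any],
    max_retries: int = 3,
) -> tuple[Any, int]:
    """Generate code, run it, and feed any error back into the next attempt.

    `generate` and `execute` are hypothetical stand-ins for the LLM planning
    step and the sandbox. Returns (result, attempts_used) and raises
    RuntimeError if every attempt fails.
    """
    code, error = "", ""
    for attempt in range(max_retries):
        # Both strings are empty on the first attempt, like previous_code/previous_error.
        code = generate(code, error)
        try:
            result = execute(code)
        except Exception as exc:
            error = str(exc)  # carried into the next generate() call
            continue
        return result, attempt + 1
    raise RuntimeError(f"failed after {max_retries} attempts: {error}")


# Usage: a fake "LLM" that gets it wrong once, then fixes the code after seeing the error.
def fake_generate(previous_code: str, previous_error: str) -> str:
    if not previous_error:
        return "1 / 0"          # first try: raises ZeroDivisionError
    return "sum([1, 2, 3])"     # retry: a corrected expression


# eval is only a stand-in here; it is NOT a sandbox and must never run untrusted LLM output.
result, attempts = run_with_feedback(fake_generate, lambda src: eval(src))
print(result, attempts)  # 6 2
```

This sketch only catches execution errors. The agent's loop also wraps the planning call, so a malformed LLM response is retried in the same way.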

## Streaming results to a live report

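While the agent runs, it streams results to the Flyte UI with `flyte.report.log.aio()`. The stream covers the dataset summary, each experiment as it completes, every analysis decision and exploration, and finally the rankings and experiment journey. Each call passes `do_flush=True`, so you don't have to wait for the full run to finish to see how experiments are performing.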
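The listing interpolates the LLM's reasoning and generated code directly into `<pre>` tags. Generated code often contains `<`, `>` or `&`, and a browser can parse those characters as markup. Below is a minimal sketch of building the same per-experiment fragment with the standard-library `html.escape` applied to free text. The helper name `experiment_report_html` is hypothetical, and the escaping is a suggested hardening rather than part of the tutorial's code.

```python
import html


def experiment_report_html(
    iteration: int,
    name: str,
    algorithm: str,
    metrics: dict,
    attempts: int,
    reasoning: str = "",
    code: str = "",
) -> str:
    """Build the HTML fragment logged for one completed experiment (hypothetical helper)."""
    esc = html.escape  # escapes &, <, > and, by default, quote characters
    parts = [
        f"<h3>Iteration {iteration}: {esc(name)}</h3>",
        f"<p><b>Algorithm:</b> {esc(algorithm)} | "
        f"<b>ROC-AUC:</b> {metrics.get('roc_auc')} | "
        f"<b>F1:</b> {metrics.get('f1')} | "
        f"<b>Recall:</b> {metrics.get('recall')} | "
        f"<b>Attempts:</b> {attempts}</p>",
    ]
    if reasoning:
        parts.append(f"<details><summary>Reasoning</summary><pre>{esc(reasoning)}</pre></details>")
    if code:
        # Comparisons such as "x < 3" would otherwise be parsed as tags inside <pre>.
        parts.append(f"<details><summary>Generated Code</summary><pre>{esc(code)}</pre></details>")
    return "".join(parts)


fragment = experiment_report_html(
    1, "xgb_baseline", "xgboost", {"roc_auc": 0.91, "f1": 0.62, "recall": 0.7}, 1,
    code="if x < 3 and y > 1:\n    pass",
)
print("&lt;" in fragment)  # True
```

Inside the agent, the returned string would take the place of the inline `html` built in the experiment loop. It would then be passed to `await flyte.report.log.aio(fragment, do_flush=True)` before logging to the live report.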

The entrypoint task enables this with `report=True`:

```
"""MLE Agent — orchestrates ML experiments using Flyte's durable sandbox.

The agent:
  1. Profiles the dataset using a trusted tool (data never touches the LLM).
  2. Asks OpenAI to design a set of experiments (algorithms, hyperparams, feature strategy).
  3. For each experiment, generates Monty orchestration code and executes it via
     flyte.sandbox.orchestrate_local(), which dispatches the heavy compute as durable tasks.
  4. Analyzes results, iterates if needed.
  5. Produces a model card summarizing the winning model.

The Monty sandbox ensures the LLM-generated orchestration code is safe — it can only
call the pre-approved tool functions and has no access to imports, network, or filesystem.
"""

import asyncio
import inspect
import json
import os
import textwrap
from dataclasses import dataclass

import flyte
import flyte.sandbox
from flyte.io import File

from mle_bot.schemas import ExperimentConfig, InitialDesign, IterationDecision

from mle_bot.environments import agent_env
from mle_bot.tools.data import profile_dataset, split_dataset
from mle_bot.tools.evaluation import evaluate_model, rank_experiments
from mle_bot.tools.exploration import explore_dataset
from mle_bot.tools.features import engineer_features
from mle_bot.tools.predictions import get_predictions
from mle_bot.tools.resampling import resample_dataset
from mle_bot.tools.selection import select_features
from mle_bot.tools.training import train_model

# {{docs-fragment tools}}
# All tools exposed to the sandbox.
# Function names must match the names used in LLM-generated orchestration code.
TOOLS = [
    profile_dataset, split_dataset, explore_dataset,
    engineer_features, resample_dataset, select_features,
    train_model, get_predictions, evaluate_model, rank_experiments,
]
TOOLS_BY_NAME = {t.func.__name__ if hasattr(t, "func") else t.__name__: t for t in TOOLS}
# {{/docs-fragment tools}}

# ---------------------------------------------------------------------------
# Prompt builders
# ---------------------------------------------------------------------------

def _tool_signatures() -> str:
    """Build a summary of available tool signatures and docstrings for the system prompt."""
    parts = []
    for t in TOOLS:
        func = t.func if hasattr(t, "func") else t
        sig = inspect.signature(func)
        doc = inspect.getdoc(func) or ""
        # Trim docstring to first 40 lines so prompt stays manageable
        doc_lines = doc.splitlines()[:40]
        short_doc = "\n    ".join(doc_lines)
        parts.append(f"async def {func.__name__}{sig}:\n    \"\"\"{short_doc}\"\"\"\n    ...")
    return "\n\n".join(parts)

# {{docs-fragment orchestration_prompt}}
def _build_orchestration_system_prompt(profile: dict) -> str:
    monty_rules = flyte.sandbox.ORCHESTRATOR_SYNTAX_PROMPT
    tools_section = _tool_signatures()
    is_imbalanced = profile.get("is_imbalanced", False)
    class_balance = profile.get("class_balance", {})
    columns = profile.get("columns", [])
    numeric_cols = profile.get("numeric_columns", [])
    categorical_cols = profile.get("categorical_columns", [])
    corr = profile.get("feature_target_corr", {})
    corr_str = ", ".join(f"{k}: {v:+.3f}" for k, v in list(corr.items())[:8]) if corr else "n/a"
    shape = profile.get("shape", [0, 0])
    return f"""\
You are an expert ML engineer. Your job is to design and write the best possible
pipeline for a machine learning experiment, then generate the Python orchestration
code to execute it.

The code runs inside a restricted sandbox. The last expression in your code
is returned as the result. All tool calls are made like regular function calls —
you do NOT need to await them.

## Dataset context

Shape: {shape[0]:,} rows × {shape[1]} columns
Numeric features: {numeric_cols}
Categorical features (excluded from model — not supported): {categorical_cols}
Class balance: {class_balance}, imbalanced: {is_imbalanced}
Feature-target correlations (raw, point-biserial): {corr_str}

## General ML best practices — apply these based on the dataset context above

**Feature engineering** (engineer_features tool):
- Sequential/time-series data (timestamp column present, rows ordered over time):
  rolling window features (means, stds, min/max) capture trends that point-in-time
  readings miss. Lag features capture recent history. Choose window sizes relative
  to the prediction horizon and temporal resolution of the data.
- Tabular cross-sectional data: normalization helps linear models and distance-based
  methods. Interaction terms can help if correlations are weak individually.
- Consider skipping feature engineering entirely for a baseline — it establishes
  whether raw features already carry enough signal.

**Class imbalance** (when is_imbalanced=true):
- Tree ensembles: use class_weight="balanced" or scale_pos_weight=n_neg/n_pos.
- Threshold: the default 0.5 decision threshold may not be optimal — the model's
  probability output is what matters, threshold is tuned at deployment time.
- Metric: ROC-AUC is robust to imbalance; avg_precision is better when positives
  are very rare.

**Algorithm selection**:
- XGBoost / GradientBoosting: strong default for tabular data, handles missing
  values, built-in imbalance handling. Start here unless data is very small.
- RandomForest: more robust to outliers, good for noisy data, parallelizes well.
- LogisticRegression: fast linear baseline. Useful to establish whether the
  problem is linearly separable before adding complexity.
- Prefer simpler models when n_samples < 5,000 to avoid overfitting.

**Resampling** (resample_dataset tool) — data-level imbalance handling:
- Use when class_weight/scale_pos_weight alone isn't improving recall adequately,
  or when you want to test whether data-level vs algorithm-level imbalance handling
  works better for this dataset.
- ONLY resample the TRAIN split — never test. Resampling test data gives misleading metrics.
- "oversample": fast, no new information, good first try.
- "smote": synthetic samples via interpolation — more diverse than random oversampling,
  better for high-dimensional or sparse minority classes.
- "undersample": loses majority data — only useful when majority class is very large
  and training speed is a concern.

**Feature selection** (select_features tool) — prune after feature engineering:
- Use after engineer_features when the feature count is large (20+) and you suspect
  many features are redundant or noisy (e.g., rolling stats at many window sizes).
- "mutual_info": ranks by non-linear dependence with target — best general choice.
- "variance_threshold": drops near-constant features — cheap first pass.
- "correlation_filter": drops redundant features that are highly correlated with
  each other — useful when many rolling windows capture the same trend.
- Can be applied before or after splitting. Apply the same selection to both train
  and test to ensure the model sees the same features at evaluation time.

**Prediction output** (get_predictions tool) — enables two advanced patterns:
1. Error analysis: train a model → get_predictions(model, test_file, target) →
   explore_dataset(predictions_file, {{"class_distributions": ["feature_x"],
   "target_column": "correct"}}) to see which examples the model gets wrong.
   Use this to inform feature engineering for the next iteration.
2. Stacking: train base_model → get_predictions(base_model, train_file, target) →
   train a meta_model on the predictions CSV (use "predicted_prob" as a feature
   alongside original features) → evaluate meta_model on test.
   get_predictions returns a CSV with columns: all originals + predicted_prob,
   predicted_class, correct.

**Pipeline structure** — you are not required to follow a fixed sequence.
Design what makes sense for this specific experiment.

## Available tools

{tools_section}

## Monty sandbox restrictions

{monty_rules}

## Critical patterns for using tool results

split_dataset returns a File — call it twice:
    train_file = split_dataset(data, target_column, 0.2, time_column, "train")
    test_file  = split_dataset(data, target_column, 0.2, time_column, "test")

engineer_features returns a File — chain calls freely:
    eng = engineer_features(train_file, {{"rolling_columns": [...], "windows": [...]}})
    eng2 = engineer_features(eng, {{"normalize": True, "target_column": target_column}})

train_model returns a File — pass directly to evaluate_model:
    model_file = train_model(train_file, target_column, algorithm, hyperparams)
    eval_result = evaluate_model(model_file, test_file, target_column)

evaluate_model returns a dict — subscript reads are allowed:
    roc = eval_result["metrics"]["roc_auc"]

Do NOT use augmented assignment (+=), subscript assignment (d["k"]=v), or try/except.
Build dicts as literals only. The last expression (no assignment) is the return value.

## When fixing a previous error

Read the error and the failing code carefully before writing a fix. Identify the root
cause — do not just change variable names or add no-ops. Trace what each tool returns,
what each subsequent call expects, and where the mismatch is. Then fix the underlying
logic, not just the surface symptom.

## Pipeline design — you decide the structure

You are NOT required to follow a fixed sequence. Design the pipeline that makes most
sense for the experiment. Examples of valid approaches:

Baseline (no feature engineering):
    train_file = split_dataset(data, target_column, 0.2, time_column, "train")
    test_file = split_dataset(data, target_column, 0.2, time_column, "test")
    model_file = train_model(train_file, target_column, algorithm, hyperparams)
    eval_result = evaluate_model(model_file, test_file, target_column)
    {{"experiment_name": experiment_name, "algorithm": algorithm, "metrics": eval_result["metrics"], "confusion_matrix": eval_result["confusion_matrix"], "threshold_analysis": eval_result["threshold_analysis"], "n_samples": eval_result["n_samples"]}}

Two-stage feature engineering (rolling then normalize separately):
    train_file = split_dataset(data, target_column, 0.2, time_column, "train")
    test_file = split_dataset(data, target_column, 0.2, time_column, "test")
    rolled_train = engineer_features(train_file, {{"rolling_columns": ["vibration"], "windows": [6, 24]}})
    rolled_test  = engineer_features(test_file,  {{"rolling_columns": ["vibration"], "windows": [6, 24]}})
    eng_train = engineer_features(rolled_train, {{"normalize": True, "target_column": target_column}})
    eng_test  = engineer_features(rolled_test,  {{"normalize": True, "target_column": target_column}})
    model_file = train_model(eng_train, target_column, algorithm, hyperparams)
    eval_result = evaluate_model(model_file, eng_test, target_column)
    {{"experiment_name": experiment_name, "algorithm": algorithm, "metrics": eval_result["metrics"], "confusion_matrix": eval_result["confusion_matrix"], "threshold_analysis": eval_result["threshold_analysis"], "n_samples": eval_result["n_samples"]}}

Compare two class weightings and return the better model:
    train_file = split_dataset(data, target_column, 0.2, time_column, "train")
    test_file = split_dataset(data, target_column, 0.2, time_column, "test")
    model_a = train_model(train_file, target_column, "xgboost", {{"n_estimators": 100, "scale_pos_weight": 10}})
    model_b = train_model(train_file, target_column, "xgboost", {{"n_estimators": 100, "scale_pos_weight": 33}})
    eval_a = evaluate_model(model_a, test_file, target_column)
    eval_b = evaluate_model(model_b, test_file, target_column)
    best_eval = eval_a if eval_a["metrics"]["roc_auc"] > eval_b["metrics"]["roc_auc"] else eval_b
    {{"experiment_name": experiment_name, "algorithm": "xgboost", "metrics": best_eval["metrics"], "confusion_matrix": best_eval["confusion_matrix"], "threshold_analysis": best_eval["threshold_analysis"], "n_samples": best_eval["n_samples"]}}

SMOTE oversampling before training:
    train_file = split_dataset(data, target_column, 0.2, time_column, "train")
    test_file  = split_dataset(data, target_column, 0.2, time_column, "test")
    eng_train  = engineer_features(train_file, {{"rolling_columns": ["vibration_mms"], "windows": [6, 12]}})
    eng_test   = engineer_features(test_file,  {{"rolling_columns": ["vibration_mms"], "windows": [6, 12]}})
    resampled_train = resample_dataset(eng_train, target_column, {{"strategy": "smote", "target_ratio": 0.2}})
    model_file = train_model(resampled_train, target_column, algorithm, hyperparams)
    eval_result = evaluate_model(model_file, eng_test, target_column)
    {{"experiment_name": experiment_name, "algorithm": algorithm, "metrics": eval_result["metrics"], "confusion_matrix": eval_result["confusion_matrix"], "threshold_analysis": eval_result["threshold_analysis"], "n_samples": eval_result["n_samples"]}}

Feature engineering followed by feature selection:
    train_file = split_dataset(data, target_column, 0.2, time_column, "train")
    test_file  = split_dataset(data, target_column, 0.2, time_column, "test")
    eng_train  = engineer_features(train_file, {{"rolling_columns": ["vibration_mms", "temperature_c"], "windows": [6, 12, 24]}})
    eng_test   = engineer_features(test_file,  {{"rolling_columns": ["vibration_mms", "temperature_c"], "windows": [6, 12, 24]}})
    sel_train  = select_features(eng_train, target_column, {{"method": "mutual_info", "k": 15}})
    sel_test   = select_features(eng_test,  target_column, {{"method": "mutual_info", "k": 15}})
    model_file = train_model(sel_train, target_column, algorithm, hyperparams)
    eval_result = evaluate_model(model_file, sel_test, target_column)
    {{"experiment_name": experiment_name, "algorithm": algorithm, "metrics": eval_result["metrics"], "confusion_matrix": eval_result["confusion_matrix"], "threshold_analysis": eval_result["threshold_analysis"], "n_samples": eval_result["n_samples"]}}

Error analysis — explore what the model gets wrong, then return that as insight:
    train_file = split_dataset(data, target_column, 0.2, time_column, "train")
    test_file  = split_dataset(data, target_column, 0.2, time_column, "test")
    model_file = train_model(train_file, target_column, algorithm, hyperparams)
    pred_file  = get_predictions(model_file, test_file, target_column)
    error_analysis = explore_dataset(pred_file, {{"target_column": "correct", "class_distributions": ["vibration_mms", "temperature_c"]}})
    eval_result = evaluate_model(model_file, test_file, target_column)
    {{"experiment_name": experiment_name, "algorithm": algorithm, "metrics": eval_result["metrics"], "confusion_matrix": eval_result["confusion_matrix"], "threshold_analysis": eval_result["threshold_analysis"], "n_samples": eval_result["n_samples"], "error_analysis": error_analysis}}

The last expression MUST be a dict with at minimum these keys:
    experiment_name, algorithm, metrics, confusion_matrix, threshold_analysis, n_samples
Additional keys (e.g. error_analysis) are allowed and will appear in the report.

## Response format

Respond in exactly this format:

## Reasoning
[Your thinking: what pipeline makes sense for this experiment and why. Consider whether
feature engineering helps, whether class imbalance needs special treatment, whether
chaining multiple steps adds value, etc.]

## Code
```python
[your orchestration code]
```
"""
# {{/docs-fragment orchestration_prompt}}

def _build_analysis_system_prompt(max_iterations: int, current_iteration: int) -> str:
    remaining = max_iterations - current_iteration - 1
    return f"""\
You are an expert ML engineer analyzing experiment results to guide the next iteration
of model development.

You must respond with valid JSON only — no markdown, no explanation outside the JSON.

Response format:
{{
  "should_continue": true | false,
  "reasoning": "What you observed, what it tells you, and what to try next",
  "exploration_requests": [
    {{
      "question": "The specific hypothesis you are testing, e.g. 'Do failure cases show meaningfully higher vibration than healthy cases?'",
      "analysis_type": "class_distributions",
      "target_column": "failure_24h",
      "class_distributions": ["vibration_mms", "temperature_c"]
    }}
  ],
  "next_experiments": [
    {{
      "name": "descriptive experiment name",
      "algorithm": "xgboost" | "random_forest" | "gradient_boosting" | "logistic_regression",
      "hyperparams": {{ ... algorithm-specific hyperparams ... }},
      "feature_config": {{
        "group_column": "...",
        "time_column": "...",
        "rolling_columns": [...],
        "windows": [...],
        "lag_columns": [...],
        "lags": [...],
        "normalize": true | false,
        "drop_columns": [...],
        "fillna_method": "forward"
      }},
      "rationale": "Why this experiment is worth trying"
    }}
  ]
}}

exploration_requests rules:
- Max 2 requests per iteration.
- Each request targets EXACTLY ONE analysis_type. Do not mix multiple types in one request.
- Supported analysis_type values and their required config fields:
    "class_distributions" → requires: target_column, class_distributions (list of columns)
    "correlation_matrix"  → requires: correlation_matrix: true
    "temporal_trend"      → requires: temporal_trend: {{time_column, target_column, freq}}
    "group_stats"         → requires: group_stats: {{group_column, target_column}}
    "outlier_summary"     → requires: outlier_summary (list of columns)
    "feature_target_corr_by_group" → requires: feature_target_corr_by_group: {{group_column, target_column, feature_columns}}
- The "question" field is required. It must be a specific testable hypothesis, not a
  general request. Bad: "explore the data". Good: "Is vibration_mms higher for failures?"
- Set exploration_requests to [] if the current results already tell you enough to
  design the next experiments. Only explore when you have a concrete unanswered question.

When deciding next experiments, reason about WHAT WAS TRIED vs what hasn't been explored.
Each result includes used_feature_engineering, used_rolling_features, used_lag_features.
Think systematically: if no feature engineering was tried yet, does the data profile
suggest it would help (weak raw correlations, temporal/sequential structure)?
If feature engineering helped, can it be improved? Avoid experiments identical to ones tried.

Iteration context: this is iteration {current_iteration + 1} of {max_iterations} requested.
Remaining iterations allowed: {remaining}.
Set should_continue=false only if:
- Best ROC-AUC >= 0.97, OR
- No remaining iterations (remaining == 0), OR
- Results have genuinely plateaued (< 0.005 ROC-AUC improvement over last iteration
  AND you have already tried the most promising directions)
Otherwise keep exploring — the user asked for {max_iterations} iterations for a reason.
"""

def _build_initial_design_system_prompt() -> str:
    return """\
You are an expert ML engineer. Given a dataset profile and a problem description,
design the first batch of experiments to run.

You must respond with valid JSON only — no markdown, no explanation outside the JSON.

Response format:
{
  "problem_type": "binary_classification",
  "primary_metric": "roc_auc" | "f1" | "recall",
  "reasoning": "Brief description of your strategy",
  "experiments": [
    {
      "name": "descriptive experiment name",
      "algorithm": "xgboost" | "random_forest" | "gradient_boosting" | "logistic_regression",
      "hyperparams": { ... algorithm-specific hyperparams ... },
      "feature_config": {
        "group_column": "",
        "time_column": "",
        "rolling_columns": [],
        "windows": [],
        "lag_columns": [],
        "lags": [],
        "normalize": false,
        "drop_columns": [],
        "fillna_method": "forward"
      },
      "rationale": "Why this experiment makes sense given the data profile"
    }
  ]
}

Design 2-3 experiments for the first batch. Good first batches typically include:
- A fast baseline to establish a floor (e.g. logistic_regression with default settings)
- Your best initial hypothesis given the data profile
- Optionally one variant that tests a specific idea suggested by the profile

Use the dataset profile to guide your choices:
- feature_target_corr: weak raw correlations suggest feature engineering may help
- categorical_columns: note these are excluded from the model automatically
- is_imbalanced: handle with class_weight or scale_pos_weight
- Shape and column types should inform algorithm complexity (simpler models for small datasets)
- A time column suggests sequential structure; lag/rolling features may capture temporal patterns

The feature_config in each experiment describes what engineer_features should apply.
Leave all fields empty/false if no feature engineering is needed for that experiment.
The orchestration code generator will decide the exact pipeline — your job here is
to specify what the experiment is trying to learn, not to prescribe every implementation detail.
"""

# ---------------------------------------------------------------------------
# LLM client
# ---------------------------------------------------------------------------

def _openai_client():
    from openai import OpenAI
    return OpenAI(api_key=os.environ["OPENAI_API_KEY"])

async def _call_llm(system: str, messages: list[dict], model: str = "gpt-4o") -> str:
    """Call OpenAI and return the response text."""
    client = _openai_client()
    response = await asyncio.to_thread(
        client.chat.completions.create,
        model=model,
        messages=[{"role": "system", "content": system}, *messages],
        temperature=0.2,
    )
    return response.choices[0].message.content

def _extract_code(text: str) -> str:
    """Pull Python code out of a markdown code block."""
    if "```python" in text:
        start = text.index("```python") + len("```python")
        end = text.index("```", start)
        return text[start:end].strip()
    if "```" in text:
        start = text.index("```") + 3
        end = text.index("```", start)
        return text[start:end].strip()
    return text.strip()

def _extract_reasoning(text: str) -> str:
    """Extract the ## Reasoning section from LLM response."""
    if "## Reasoning" in text:
        start = text.index("## Reasoning") + len("## Reasoning")
        if "## Code" in text:
            end = text.index("## Code")
            return text[start:end].strip()
        return text[start:].strip()
    return ""

def _parse_json(text: str) -> dict:
    """Extract and parse JSON from LLM response."""
    text = text.strip()
    if "```json" in text:
        start = text.index("```json") + 7
        end = text.index("```", start)
        text = text[start:end].strip()
    elif "```" in text:
        start = text.index("```") + 3
        end = text.index("```", start)
        text = text[start:end].strip()
    return json.loads(text)

# ---------------------------------------------------------------------------
# Display helpers
# ---------------------------------------------------------------------------

def _recommend_threshold(threshold_analysis: list, min_precision: float = 0.70) -> dict | None:
    """Find the threshold that maximises recall subject to precision >= min_precision."""
    candidates = [t for t in threshold_analysis if t["precision"] >= min_precision]
    if not candidates:
        return None
    return max(candidates, key=lambda t: t["recall"])

def _print_experiment_table(results: list["ExperimentResult"], best_name: str) -> None:
    """Print a ranked comparison table of all experiments."""
    sorted_results = sorted(results, key=lambda r: r.metrics.get("roc_auc", 0), reverse=True)
    print("\n" + "─" * 78)
    print(f"  {'Rank':<5} {'Experiment':<32} {'ROC-AUC':<9} {'F1':<7} {'Recall':<8} {'Note'}")
    print("─" * 78)
    for rank, r in enumerate(sorted_results, 1):
        note = "◀ winner" if r.name == best_name else ""
        roc = r.metrics.get("roc_auc", 0)
        f1 = r.metrics.get("f1", 0)
        recall = r.metrics.get("recall", 0)
        print(f"  {rank:<5} {r.name:<32} {roc:<9.4f} {f1:<7.4f} {recall:<8.4f} {note}")
    print("─" * 78)

def _print_threshold_recommendation(threshold_analysis: list, default_metrics: dict) -> None:
    """Print the operational threshold recommendation."""
    rec = _recommend_threshold(threshold_analysis)
    if not rec:
        return
    default_recall = default_metrics.get("recall", 0)
    default_precision = default_metrics.get("precision", 0)
    missed_pct = round((1 - rec["recall"]) * 100, 1)
    false_alarm_pct = round((1 - rec["precision"]) * 100, 1)

    print(f"\n  Recommended decision threshold: {rec['threshold']}")
    print(f"  ├─ Precision : {rec['precision']:.0%}   ({false_alarm_pct}% of alerts are false alarms)")
    print(f"  ├─ Recall    : {rec['recall']:.0%}   (catches {rec['recall']*100:.0f}% of actual failures)")
    print(f"  └─ F1        : {rec['f1']:.4f}")
    print(f"  Default threshold (0.5): Precision={default_precision:.0%}, Recall={default_recall:.0%}")
    if rec["recall"] > default_recall:
        extra = round((rec["recall"] - default_recall) * 100, 1)
        print(f"  → Lowering threshold catches {extra} percentage points more failures at the cost of more alerts")

# ---------------------------------------------------------------------------
# Orchestration code generation (durable Flyte task with Flyte report)
# ---------------------------------------------------------------------------

@agent_env.task
async def plan_experiment(
    experiment_json: str,
    profile_json: str,
    target_column: str,
    time_column: str,
    previous_error: str = "",
    previous_code: str = "",
    llm_model: str = "gpt-4o",
) -> str:
    """LLM plans a single experiment: reasons about the pipeline and generates Monty code.

    Runs as a durable Flyte task so each experiment's planning step is traceable.
    Returns a JSON string: {"code": "...", "reasoning": "..."}.

    Args:
        experiment_json: JSON string of the experiment spec (name, algorithm, hyperparams, ...).
        profile_json: JSON string of the dataset profile from profile_dataset.
        target_column: Name of the target column.
        time_column: Time column for temporal splitting, or empty string.
        previous_error: Error message from the previous attempt (empty on first try).
        previous_code: Code that failed on the previous attempt (empty on first try).
        llm_model: OpenAI model identifier.

    Returns:
        str — JSON string with keys "code" and "reasoning".
    """
    experiment = json.loads(experiment_json)
    profile = json.loads(profile_json)
    exp_name = experiment.get("name", "experiment")

    # Strip rationale — it was written by the design LLM to explain *why* this
    # experiment was chosen. Passing it here causes plan_experiment to parrot it
    # back as "reasoning" instead of independently thinking about *how* to build
    # the best pipeline. Keep only the technical spec.
    pipeline_spec = {
        k: v for k, v in experiment.items()
        if k not in ("rationale",)
    }

    system = _build_orchestration_system_prompt(profile)

    user_content = textwrap.dedent(f"""
        Design and implement the best pipeline for this experiment:

        Name: {exp_name}
        Algorithm: {pipeline_spec.get("algorithm")}
        Hyperparams: {json.dumps(pipeline_spec.get("hyperparams", {}), indent=2)}
        Feature config hint: {json.dumps(pipeline_spec.get("feature_config", {}), indent=2)}

        Available sandbox inputs:
        - data: File  — the full dataset CSV
        - target_column: str = "{target_column}"
        - time_column: str = "{time_column}"  (empty string means no time ordering)
        - experiment_name: str = "{exp_name}"

        The feature config hint is a suggestion from the experiment designer — you can
        follow it, improve on it, or override it if the dataset context and your ML
        judgment suggest a better approach. In your ## Reasoning, explain your actual
        pipeline decisions: what you chose to do (or not do) and why, based on the
        dataset profile above. Do not restate the experiment name or why it was chosen.
    """).strip()

    messages = [{"role": "user", "content": user_content}]
    if previous_code and previous_error:
        messages = [
            {"role": "user", "content": user_content},
            {"role": "assistant", "content": f"```python\n{previous_code}\n```"},
            {"role": "user", "content": f"That code failed with this error:\n\n{previous_error}\n\nPlease fix it."},
        ]

    response = await _call_llm(system, messages, llm_model)
    reasoning = _extract_reasoning(response)
    code = _extract_code(response)
    return json.dumps({"code": code, "reasoning": reasoning})

@flyte.trace
async def design_experiments(
    problem_description: str,
    profile_json: str,
    llm_model: str = "gpt-4o",
) -> str:
    """LLM designs the initial batch of experiments given problem + dataset profile.

    Traced so the prompt/response is visible in the Flyte UI and results are
    cached for deterministic replay on crash/retry.
    Returns raw LLM response (JSON string matching InitialDesign schema).
    """
    design_prompt = textwrap.dedent(f"""
        Problem description: {problem_description}

        Dataset profile:
        {profile_json}

        Design the first batch of experiments.
    """).strip()
    return await _call_llm(
        _build_initial_design_system_prompt(),
        [{"role": "user", "content": design_prompt}],
        llm_model,
    )

@flyte.trace
async def analyze_iteration(
    analysis_prompt: str,
    max_iterations: int,
    current_iteration: int,
    llm_model: str = "gpt-4o",
) -> str:
    """LLM analyzes experiment results and decides whether/how to continue.

    Traced so the prompt/response is visible in the Flyte UI and results are
    cached for deterministic replay on crash/retry.
    Returns raw LLM response (JSON string matching IterationDecision schema).
    """
    return await _call_llm(
        _build_analysis_system_prompt(max_iterations, current_iteration),
        [{"role": "user", "content": analysis_prompt}],
        llm_model,
    )

@flyte.trace
async def plan_followup(
    analysis_prompt: str,
    analysis_response: str,
    followup_prompt: str,
    max_iterations: int,
    current_iteration: int,
    llm_model: str = "gpt-4o",
) -> str:
    """LLM designs next experiments after targeted data explorations.

    Traced so the prompt/response is visible in the Flyte UI and results are
    cached for deterministic replay on crash/retry.
    Returns raw LLM response (JSON string with {"next_experiments": [...]}).
    """
    return await _call_llm(
        _build_analysis_system_prompt(max_iterations, current_iteration),
        [
            {"role": "user", "content": analysis_prompt},
            {"role": "assistant", "content": analysis_response},
            {"role": "user", "content": followup_prompt},
        ],
        llm_model,
    )

def _corrupt_experiment_for_demo(exp_dict: dict) -> dict:
    """Introduce a deliberate error into the first experiment for demo purposes.

    Corrupts the algorithm name so the LLM must recover from a known-bad value.
    The retry loop will catch this, regenerate with the error message, and fix it.
    """
    corrupted = dict(exp_dict)
    corrupted["algorithm"] = corrupted["algorithm"] + "_INVALID"
    return corrupted

# ---------------------------------------------------------------------------
# Main agent loop
# ---------------------------------------------------------------------------

@dataclass
class ExperimentResult:
    name: str
    algorithm: str
    metrics: dict
    confusion_matrix: dict
    threshold_analysis: list
    n_samples: int
    code: str
    attempts: int
    reasoning: str = ""
    error: str = ""

@dataclass
class AgentResult:
    model_card: str
    best_experiment: str
    best_metrics: dict
    all_results: list[ExperimentResult]
    iterations: int
    total_experiments: int

async def _run_experiment(
    exp: "ExperimentConfig",
    exp_dict: dict,
    inject_failure: bool,
    data: File,
    target_column: str,
    time_column: str,
    profile: dict,
    llm_model: str,
    max_retries: int,
) -> "ExperimentResult | None":
    """Run a single experiment with retries. Returns None on total failure."""
    exp_name = exp.name
    profile_json = json.dumps(profile)
    print(f"\n   ┌─ {exp_name}  [{exp.algorithm}]")
    if exp.rationale:
        for line in textwrap.wrap(exp.rationale, width=58):
            print(f"   │  {line}")
    if inject_failure:
        print(f"   │  [injecting failure for demo: algorithm='{exp_dict['algorithm']}']")

    code = ""
    error = ""
    result = None
    attempt = 0

    reasoning = ""
    # {{docs-fragment retry_loop}}
    for attempt in range(max_retries):
        try:
            with flyte.group(exp_name):
                plan_json = await plan_experiment.aio(
                    experiment_json=json.dumps(exp_dict),
                    profile_json=profile_json,
                    target_column=target_column,
                    time_column=time_column,
                    previous_error=error,
                    previous_code=code,
                    llm_model=llm_model,
                )
                plan = json.loads(plan_json)
                code = plan["code"]
                reasoning = plan.get("reasoning", "")
                result = await flyte.sandbox.orchestrate_local(
                    code,
                    inputs={"data": data, "target_column": target_column,
                            "time_column": time_column, "experiment_name": exp_name},
                    tasks=TOOLS,
                )
            error = ""
            break
        except Exception as exc:
            error = str(exc)
            short_error = error[:100] + "..." if len(error) > 100 else error
            print(f"   │  attempt {attempt + 1} failed: {short_error}")
            print(f"   │  → asking LLM to fix and retry...")
            if inject_failure and attempt == 0:
                exp_dict = exp.model_dump()
    # {{/docs-fragment retry_loop}}

    if result and not error:
        exp_result = ExperimentResult(
            name=exp_name,
            algorithm=exp.algorithm,
            metrics=result.get("metrics", {}),
            confusion_matrix=result.get("confusion_matrix", {}),
            threshold_analysis=result.get("threshold_analysis", []),
            n_samples=result.get("n_samples", 0),
            code=code,
            reasoning=reasoning,
            attempts=attempt + 1,
        )
        m = exp_result.metrics
        attempts_note = f" (recovered after {attempt + 1} attempts)" if attempt > 0 else ""
        print(f"   └─ ROC-AUC={m.get('roc_auc')}, F1={m.get('f1')}, Recall={m.get('recall')}{attempts_note}")
        return exp_result

    print(f"   └─ FAILED after {max_retries} attempts — skipping.")
    return None

async def run_agent(
    data: File,
    problem_description: str,
    target_column: str,
    time_column: str = "",
    max_iterations: int = 3,
    max_retries_per_experiment: int = 3,
    llm_model: str = "gpt-4o",
    inject_failure: bool = False,
) -> AgentResult:
    """Run the MLE agent end-to-end.

    Args:
        data: CSV file containing the dataset.
        problem_description: Natural language description of the ML problem.
        target_column: Name of the target column to predict.
        time_column: Optional column to use for time-based train/test split.
        max_iterations: Maximum number of experiment iterations to run.
        max_retries_per_experiment: Max times to retry a failed sandbox execution.
        llm_model: OpenAI model to use (default: gpt-4o).
        inject_failure: If True, corrupts the first experiment to demonstrate self-healing.
    """
    print(f"\n{'='*60}")
    print(f"MLE Agent starting")
    print(f"Problem: {problem_description}")
    print(f"Target: {target_column}")
    if inject_failure:
        print(f"[demo mode: failure injection enabled]")
    print(f"{'='*60}\n")

    # {{docs-fragment phase1_profile}}
    # --- Phase 1: Profile the dataset (trusted tool, LLM never sees raw data) ---
    print(">> Phase 1: Profiling dataset...")
    with flyte.group("profile"):
        profile = await profile_dataset(data, target_column)
    # {{/docs-fragment phase1_profile}}
    print(f"   Shape: {profile['shape']}, Classes: {profile['target_distribution']}")
    print(f"   Imbalanced: {profile['is_imbalanced']}, Columns: {len(profile['columns'])}")
    corr = profile.get("feature_target_corr", {})
    top_corr = list(corr.items())[:5]
    print(f"   Top correlations: {', '.join(f'{k}={v:+.3f}' for k,v in top_corr)}")

    # Stream report: dataset summary
    await flyte.report.log.aio(
        f"<h1>MLE Agent Run</h1>"
        f"<p><b>Problem:</b> {problem_description}</p>"
        f"<p><b>Dataset:</b> {profile['shape'][0]:,} rows × {profile['shape'][1]} cols &nbsp;|&nbsp; "
        f"Class balance: {profile['class_balance']} &nbsp;|&nbsp; Imbalanced: {profile['is_imbalanced']}</p>"
        f"<p><b>Top feature-target correlations (raw):</b> "
        + ", ".join(f"{k}: {v:+.3f}" for k, v in top_corr) +
        f"</p><hr>",
        do_flush=True,
    )

    # --- Phase 2: LLM designs initial experiments ---
    print("\n>> Phase 2: Designing initial experiments...")
    design_response = await design_experiments(
        problem_description=problem_description,
        profile_json=json.dumps(profile),
        llm_model=llm_model,
    )
    design = InitialDesign.model_validate(_parse_json(design_response))
    print(f"   Primary metric: {design.primary_metric}")
    print(f"   Strategy: {design.reasoning}")
    print(f"   Experiments planned: {len(design.experiments)}")

    all_results: list[ExperimentResult] = []
    iteration_log: list[dict] = []  # tracks per-iteration decisions + explorations for summary
    current_experiments: list[ExperimentConfig] = design.experiments
    first_experiment = True

    # --- Phase 3: Iterative experiment loop ---
    for iteration in range(max_iterations):
        experiments = current_experiments

        if not experiments:
            print(f"\n>> No experiments to run in iteration {iteration + 1}. Stopping.")
            break

        print(f"\n>> Phase 3.{iteration + 1}: Running {len(experiments)} experiment(s) in parallel...")

        # Assign names and prepare dicts before launching in parallel
        exp_batch = []
        for i, exp in enumerate(experiments):
            if not exp.name:
                exp.name = f"experiment_{len(all_results) + i + 1}"
            exp_dict = exp.model_dump()
            inject_this = inject_failure and first_experiment and i == 0
            if inject_this:
                exp_dict = _corrupt_experiment_for_demo(exp_dict)
            first_experiment = False
            exp_batch.append((exp, exp_dict, inject_this))

        # {{docs-fragment parallel_execute}}
        batch_results = await asyncio.gather(*[
            _run_experiment(
                exp=exp,
                exp_dict=exp_dict,
                inject_failure=inject_this,
                data=data,
                target_column=target_column,
                time_column=time_column,
                profile=profile,
                llm_model=llm_model,
                max_retries=max_retries_per_experiment,
            )
            for exp, exp_dict, inject_this in exp_batch
        ])
        # {{/docs-fragment parallel_execute}}

        for exp_result in batch_results:
            if exp_result is not None:
                all_results.append(exp_result)
                # Stream report: each experiment as it completes
                m = exp_result.metrics
                html = (
                    f"<h3>Iteration {iteration + 1} — {exp_result.name}</h3>"
                    f"<p><b>Algorithm:</b> {exp_result.algorithm} &nbsp;|&nbsp; "
                    f"<b>ROC-AUC:</b> {m.get('roc_auc')} &nbsp;|&nbsp; "
                    f"<b>F1:</b> {m.get('f1')} &nbsp;|&nbsp; "
                    f"<b>Recall:</b> {m.get('recall')} &nbsp;|&nbsp; "
                    f"<b>Attempts:</b> {exp_result.attempts}</p>"
                )
                if exp_result.reasoning:
                    html += f"<details><summary>Reasoning</summary><pre>{exp_result.reasoning}</pre></details>"
                html += f"<details><summary>Generated Code</summary><pre>{exp_result.code}</pre></details>"
                await flyte.report.log.aio(html, do_flush=True)

        # --- Phase 4: Analyze results, decide whether to iterate ---
        if all_results and iteration < max_iterations - 1:
            print(f"\n>> Phase 4.{iteration + 1}: Analyzing results, deciding next steps...")
            results_summary = [
                {
                    "experiment_name": r.name,
                    "algorithm": r.algorithm,
                    "metrics": r.metrics,
                    "confusion_matrix": r.confusion_matrix,
                    "used_feature_engineering": "engineer_features" in r.code,
                    "used_rolling_features": "rolling_columns" in r.code,
                    "used_lag_features": "lag_columns" in r.code,
                }
                for r in all_results
            ]
            analysis_prompt = textwrap.dedent(f"""
                Problem: {problem_description}
                Dataset profile: shape={profile['shape']}, imbalanced={profile['is_imbalanced']}
                Feature-target correlations (raw): {json.dumps(profile.get('feature_target_corr', {}), indent=2)}

                Experiment results so far (iteration {iteration + 1}):
                {json.dumps(results_summary, indent=2)}

                Should we run more experiments? If yes, request any data explorations
                you need, then specify what experiments to run next.
            """).strip()

            analysis_response = await analyze_iteration(
                analysis_prompt=analysis_prompt,
                max_iterations=max_iterations,
                current_iteration=iteration,
                llm_model=llm_model,
            )
            decision = IterationDecision.model_validate(_parse_json(analysis_response))
            verdict = "continuing" if decision.should_continue else "stopping"
            print(f"   Decision: {verdict}")
            print(f"   Reasoning: {decision.reasoning}")

            # Stream report: analysis decision
            await flyte.report.log.aio(
                f"<h3>Analysis — Iteration {iteration + 1}</h3>"
                f"<p><b>Decision:</b> {verdict}</p>"
                f"<p><b>Reasoning:</b> {decision.reasoning}</p>",
                do_flush=True,
            )

            # Track this iteration for the experiment journey summary
            iter_entry = {
                "iteration": iteration + 1,
                "experiments": [r.name for r in batch_results if r is not None],
                "best_roc_auc": max(
                    (r.metrics.get("roc_auc", 0) for r in all_results), default=0
                ),
                "reasoning": decision.reasoning,
                "explorations": [],
            }

            # --- Targeted exploration before next iteration ---
            if decision.should_continue and decision.exploration_requests:
                print(f"   Running {len(decision.exploration_requests)} exploration request(s)...")
                exploration_questions = []
                exploration_results = []

                for i, req in enumerate(decision.exploration_requests):
                    question = req.get("question", f"Exploration {i + 1}")
                    # Strip agent-level metadata — tool only needs the analysis config
                    tool_config = {k: v for k, v in req.items() if k not in ("question", "analysis_type")}

                    print(f"   Q: {question}")
                    with flyte.group(f"explore_{iteration + 1}_{i + 1}"):
                        result = await explore_dataset(data, tool_config)
                    exploration_questions.append(question)
                    exploration_results.append(result)
                    iter_entry["explorations"].append({"question": question})

                    await flyte.report.log.aio(
                        f"<h4>Exploration {i + 1}</h4>"
                        f"<p><b>Question:</b> {question}</p>"
                        f"<details><summary>Results</summary><pre>{json.dumps(result, indent=2)}</pre></details>",
                        do_flush=True,
                    )

                # Build follow-up that explicitly connects each question to its answer
                qa_pairs = "\n\n".join(
                    f'Question {i + 1}: "{q}"\nResult:\n{json.dumps(r, indent=2)}'
                    for i, (q, r) in enumerate(zip(exploration_questions, exploration_results))
                )
                followup_prompt = textwrap.dedent(f"""
                    You requested {len(exploration_results)} targeted exploration(s).
                    Here is what you asked and what you learned:

                    {qa_pairs}

                    Given what you learned and your earlier reasoning:
                    "{decision.reasoning}"

                    Now specify the next experiments. For each experiment, briefly state
                    which exploration insight informed your choice.
                    Respond with valid JSON: {{"next_experiments": [...same schema as before...]}}
                """).strip()
                followup_response = await plan_followup(
                    analysis_prompt=analysis_prompt,
                    analysis_response=analysis_response,
                    followup_prompt=followup_prompt,
                    max_iterations=max_iterations,
                    current_iteration=iteration,
                    llm_model=llm_model,
                )
                followup = _parse_json(followup_response)
                current_experiments = IterationDecision.model_validate({
                    "should_continue": True,
                    "reasoning": decision.reasoning,
                    "next_experiments": followup.get("next_experiments", []),
                }).next_experiments
                print(f"   Post-exploration: {len(current_experiments)} experiment(s) planned")
            else:
                current_experiments = decision.next_experiments

            iteration_log.append(iter_entry)

            if not decision.should_continue:
                break

    # --- Phase 5: Rank all results and generate model card ---
    print(f"\n>> Phase 5: Ranking {len(all_results)} experiment(s) and generating model card...")

    if not all_results:
        return AgentResult(
            model_card="No experiments completed successfully.",
            best_experiment="",
            best_metrics={},
            all_results=[],
            iterations=iteration + 1,
            total_experiments=0,
        )

    ranking_input = [
        {
            "experiment_name": r.name,
            "metrics": r.metrics,
            "confusion_matrix": r.confusion_matrix,
        }
        for r in all_results
    ]
    with flyte.group("rank"):
        ranking = await rank_experiments(json.dumps(ranking_input))
    best_name = ranking["best_experiment"]
    best_result = next(r for r in all_results if r.name == best_name)

    _print_experiment_table(all_results, best_name)
    _print_threshold_recommendation(best_result.threshold_analysis, best_result.metrics)

    # Stream report: final rankings table
    rows = "".join(
        f"<tr><td>{row['rank']}</td>"
        f"<td>{'<b>' if row['experiment_name'] == best_name else ''}"
        f"{row['experiment_name']}"
        f"{'</b>' if row['experiment_name'] == best_name else ''}</td>"
        f"<td>{row['roc_auc']}</td><td>{row['f1']}</td>"
        f"<td>{row['recall']}</td><td>{row['precision']}</td></tr>"
        for row in ranking.get("ranking", [])
    )
    await flyte.report.log.aio(
        f"<hr><h2>Final Rankings</h2>"
        f"<table border='1' cellpadding='6' cellspacing='0'>"
        f"<tr><th>Rank</th><th>Experiment</th><th>ROC-AUC</th><th>F1</th><th>Recall</th><th>Precision</th></tr>"
        f"{rows}</table>"
        f"<p>{ranking.get('summary', '')}</p>",
        do_flush=True,
    )

    # Stream report: experiment journey summary
    journey_rows = ""
    for entry in iteration_log:
        exps = ", ".join(entry["experiments"]) if entry["experiments"] else "—"
        explorations = "; ".join(e["question"] for e in entry["explorations"]) if entry["explorations"] else "—"
        short_reasoning = (entry["reasoning"][:120] + "…") if len(entry["reasoning"]) > 120 else entry["reasoning"]
        journey_rows += (
            f"<tr>"
            f"<td style='text-align:center'>{entry['iteration']}</td>"
            f"<td>{exps}</td>"
            f"<td style='text-align:center'>{entry['best_roc_auc']:.4f}</td>"
            f"<td>{short_reasoning}</td>"
            f"<td>{explorations}</td>"
            f"</tr>"
        )
    await flyte.report.log.aio(
        f"<hr><h2>Experiment Journey</h2>"
        f"<table border='1' cellpadding='6' cellspacing='0' style='width:100%;border-collapse:collapse'>"
        f"<tr><th>Iter</th><th>Experiments</th><th>Best ROC-AUC</th><th>Key insight</th><th>Explorations</th></tr>"
        f"{journey_rows}"
        f"</table>",
        do_flush=True,
    )

    model_card = await _generate_model_card(
        problem_description=problem_description,
        profile=profile,
        all_results=all_results,
        best_result=best_result,
        ranking=ranking,
        iteration_log=iteration_log,
        llm_model=llm_model,
    )

    print(f"\n{'='*60}")
    print(f"DONE — Best model: {best_name}")
    print(f"       ROC-AUC={best_result.metrics.get('roc_auc')}, F1={best_result.metrics.get('f1')}")
    print(f"{'='*60}\n")

    return AgentResult(
        model_card=model_card,
        best_experiment=best_name,
        best_metrics=best_result.metrics,
        all_results=all_results,
        iterations=iteration + 1,
        total_experiments=len(all_results),
    )

async def _generate_model_card(
    problem_description: str,
    profile: dict,
    all_results: list[ExperimentResult],
    best_result: ExperimentResult,
    ranking: dict,
    iteration_log: list[dict],
    llm_model: str,
) -> str:
    """Generate a markdown model card summarizing the winning model."""
    system = textwrap.dedent("""
        You are an ML engineer writing a model card for a trained model.
        Write in markdown. Be concise but informative. Include:
        - Problem statement
        - Dataset summary
        - Experiment journey (brief per-iteration narrative: what was tried, what was learned, what changed)
        - Experiment summary (table of all experiments with metrics)
        - Winning model details (algorithm, key hyperparams, metrics, threshold analysis)
        - Recommendations for deployment (decision threshold, monitoring)
    """).strip()

    results_text = "\n".join(
        f"- {r.name} ({r.algorithm}): ROC-AUC={r.metrics.get('roc_auc')}, "
        f"F1={r.metrics.get('f1')}, Recall={r.metrics.get('recall')}"
        for r in all_results
    )

    journey_text = ""
    if iteration_log:
        journey_text = "\n\nIteration log:\n" + "\n".join(
            f"  Iteration {e['iteration']}: ran [{', '.join(e['experiments'])}], "
            f"best ROC-AUC so far={e['best_roc_auc']:.4f}. "
            f"Key insight: {e['reasoning'][:200]}. "
            + (f"Explorations: {'; '.join(x['question'] for x in e['explorations'])}" if e['explorations'] else "")
            for e in iteration_log
        )

    user_content = textwrap.dedent(f"""
        Problem: {problem_description}

        Dataset: {profile['shape'][0]} rows × {profile['shape'][1]} cols.
        Class balance: {profile['class_balance']}
        Imbalanced: {profile['is_imbalanced']}
        {journey_text}

        All experiments:
        {results_text}

        Best model: {best_result.name} ({best_result.algorithm})
        Metrics: {json.dumps(best_result.metrics, indent=2)}
        Confusion matrix: {json.dumps(best_result.confusion_matrix, indent=2)}
        Threshold analysis: {json.dumps(best_result.threshold_analysis, indent=2)}

        Ranking summary: {ranking['summary']}
    """).strip()

    response = await _call_llm(system, [{"role": "user", "content": user_content}], llm_model)
    return response

# ---------------------------------------------------------------------------
# Durable entrypoint (runs the agent as a Flyte task in the cloud)
# ---------------------------------------------------------------------------

# {{docs-fragment entrypoint}}
@agent_env.task(retries=1, report=True)
async def mle_agent_task(
    data: File,
    problem_description: str,
    target_column: str,
    time_column: str = "",
    max_iterations: int = 3,
) -> str:
    """Durable Flyte task entrypoint for the MLE agent."""
    result = await run_agent(
        data=data,
        problem_description=problem_description,
        target_column=target_column,
        time_column=time_column,
        max_iterations=max_iterations,
    )
    return result.model_card
# {{/docs-fragment entrypoint}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/mle_bot/mle_bot/agent.py*

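Once each parallel batch has finished, the agent streams every successful experiment's metrics to the report. A simplified version of that loop:

```python
for exp_result in batch_results:
    if exp_result is None:
        continue  # this experiment failed on every retry
    m = exp_result.metrics
    await flyte.report.log.aio(
        f"<h3>Iteration {iteration + 1}: {exp_result.name}</h3>"
        f"<p><b>Algorithm:</b> {exp_result.algorithm} &nbsp;|&nbsp; "
        f"<b>ROC-AUC:</b> {m.get('roc_auc')} &nbsp;|&nbsp; "
        f"<b>F1:</b> {m.get('f1')}</p>",
        do_flush=True,
    )
```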

The final report includes a dataset summary, per-experiment metrics with expandable reasoning and generated code, the analysis decisions at each iteration, a final rankings table, and an experiment journey summary showing how the agent's strategy evolved.
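One detail worth handling if you build these fragments yourself: the report loop in `agent.py` interpolates the LLM's reasoning and generated code directly into `<pre>` tags, so characters such as `<` and `&` in that text can be interpreted as HTML markup and mangle the rendered output. Below is a minimal sketch of a fragment builder that escapes the untrusted parts with the standard library's `html.escape`. The function name `experiment_html` is hypothetical and is not part of the tutorial code.

```python
import html


def experiment_html(
    iteration: int,
    name: str,
    algorithm: str,
    metrics: dict,
    code: str,
    reasoning: str = "",
) -> str:
    """Render one experiment as an HTML fragment for a streaming report.

    LLM-generated text (names, code, reasoning) is escaped so that characters
    like `<` in `df['x'] < 3` display literally instead of being parsed as tags.
    """
    parts = [
        f"<h3>Iteration {iteration}: {html.escape(name)}</h3>",
        f"<p><b>Algorithm:</b> {html.escape(algorithm)} &nbsp;|&nbsp; "
        f"<b>ROC-AUC:</b> {metrics.get('roc_auc')} &nbsp;|&nbsp; "
        f"<b>F1:</b> {metrics.get('f1')}</p>",
    ]
    if reasoning:
        parts.append(
            "<details><summary>Reasoning</summary>"
            f"<pre>{html.escape(reasoning)}</pre></details>"
        )
    parts.append(
        "<details><summary>Generated Code</summary>"
        f"<pre>{html.escape(code)}</pre></details>"
    )
    return "".join(parts)


if __name__ == "__main__":
    print(experiment_html(1, "experiment_1", "xgboost", {"roc_auc": 0.91, "f1": 0.62}, "mask = df['x'] < 3"))
```

The returned string can then be passed to `flyte.report.log.aio(..., do_flush=True)` in the same way as the snippet above.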

## Running it

First, generate the synthetic demo dataset (a predictive maintenance scenario with 175k+ rows of simulated sensor data from 20 industrial pumps):

```bash
uv run main.py generate-data
```

Then submit the agent to your Flyte cluster:

```bash
uv run main.py run \
    --data data/predictive_maintenance.csv \
    --problem "Predict pump failures 24 hours before they happen" \
    --target failure_24h \
    --time-column timestamp \
    --max-iterations 3 \
    --output results/report.md
```

The `run` command connects to your cluster using `~/.flyte/config.yaml`, uploads the CSV, and submits the agent task. It prints a URL for tracking the execution in the Flyte UI and streams logs to your terminal.

> [!NOTE]
> You'll need to register your OpenAI API key as a cluster secret before running:
> `flyte create secret openai-api-key <YOUR_KEY>`

To see the self-healing retry loop in action, add the `--inject-failure` flag. It deliberately corrupts the first experiment's configuration, so the first attempt fails and the agent has to pass the error back to the LLM and retry with corrected code.
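The recovery path follows the pattern in `_run_experiment`: each failed attempt's error message and code are handed back to the planner, which produces a corrected version. Below is a stripped-down, framework-free sketch of that loop. It assumes two hypothetical async callables: `generate(previous_code, previous_error)` stands in for the `plan_experiment` LLM task, and `execute(code)` stands in for the sandboxed run. Neither is a Flyte API.

```python
import asyncio
from typing import Awaitable, Callable, Optional


async def run_with_repair(
    generate: Callable[[str, str], Awaitable[str]],
    execute: Callable[[str], Awaitable[dict]],
    max_retries: int = 3,
) -> Optional[tuple[dict, str, int]]:
    """Retry `execute`, feeding each failure back into the next `generate` call.

    Returns (result, code, attempts) on success, or None if every attempt fails.
    """
    code, error = "", ""
    for attempt in range(1, max_retries + 1):
        try:
            # The generator sees the previous code and error ("" on the first try).
            code = await generate(code, error)
            result = await execute(code)
        except Exception as exc:
            error = str(exc)  # becomes context for the next attempt
            continue
        return result, code, attempt
    return None


async def demo() -> Optional[tuple[dict, str, int]]:
    async def generate(previous_code: str, previous_error: str) -> str:
        # Stand-in for the LLM: it only "fixes" the algorithm after seeing an error.
        return "algorithm='xgboost'" if previous_error else "algorithm='not_a_model'"

    async def execute(code: str) -> dict:
        # Stand-in for the sandbox: rejects the bogus algorithm name.
        if "not_a_model" in code:
            raise ValueError("unknown algorithm: not_a_model")
        return {"roc_auc": 0.9}

    return await run_with_repair(generate, execute)


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

Returning `None` instead of raising after the last attempt is what lets one broken experiment be skipped without aborting the rest of the batch.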

## Why Flyte?

You could build something similar with plain Python and `exec()`. But there are a few things you'd lose.

**Safety.** Flyte's sandbox restricts LLM-generated code to calling your pre-approved functions and nothing else. No imports, no network, no filesystem. If you wouldn't give an intern root access to your production cluster, you probably shouldn't give an LLM unrestricted code execution either.
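Flyte's sandbox enforces this boundary for you, and the sketch below is not how it works. It is a toy, standard-library-only illustration of the allowlist idea: parse the generated code with `ast` and flag any import and any call to a name outside an approved set. The tool names in `ALLOWED_TOOLS` are hypothetical stand-ins. A static check like this is easy to get around and is no substitute for a real sandbox.

```python
import ast

# Hypothetical tool names standing in for the pre-approved tasks the sandbox exposes.
ALLOWED_TOOLS = frozenset({"engineer_features", "train_model", "evaluate_model"})


def check_generated_code(source: str, allowed: frozenset = ALLOWED_TOOLS) -> list[str]:
    """Statically flag imports and calls to anything outside `allowed`.

    Returns a list of violation messages; an empty list means the check passed.
    """
    violations = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            violations.append(f"line {node.lineno}: imports are not allowed")
        elif isinstance(node, ast.Call):
            func = node.func
            # Only plain calls to allowlisted names pass; `os.system(...)`,
            # `__import__(...)`, `open(...)`, etc. are all rejected.
            if not (isinstance(func, ast.Name) and func.id in allowed):
                violations.append(
                    f"line {node.lineno}: call to '{ast.unparse(func)}' is not allowed"
                )
    return violations


if __name__ == "__main__":
    print(check_generated_code("feats = engineer_features(data)\nmodel = train_model(feats)"))
    print(check_generated_code("import os\nos.system('ls')"))
```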

**Durability.** Every tool call is a Flyte task. If the agent process crashes halfway through iteration 3, the experiments that already completed are cached. You restart and pick up where you left off instead of retraining models from scratch. For long-running ML experiments, this matters.
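Flyte provides this through its own task caching, which differs in the details. As a local analogue of the underlying idea (a completed step, keyed on its inputs, is reused rather than recomputed after a restart), here is a toy decorator that persists each call's result to disk. `disk_cached`, `train`, and `count_executions` are hypothetical names used only for illustration.

```python
import functools
import hashlib
import json
import os
import tempfile


def disk_cached(cache_dir: str):
    """Toy decorator: persist each call's JSON-serializable result, keyed on its inputs."""
    os.makedirs(cache_dir, exist_ok=True)

    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(**kwargs):
            # Key = function name + canonical JSON of the keyword arguments.
            key_src = json.dumps({"fn": fn.__name__, "kwargs": kwargs}, sort_keys=True)
            path = os.path.join(cache_dir, hashlib.sha256(key_src.encode()).hexdigest() + ".json")
            if os.path.exists(path):
                # Completed on an earlier run (e.g. before a crash): reuse the result.
                with open(path) as fh:
                    return json.load(fh)
            result = fn(**kwargs)
            with open(path, "w") as fh:
                json.dump(result, fh)
            return result

        return wrapper

    return decorator


def count_executions(cache_dir: str) -> int:
    """Call the same 'training step' twice and report how often it really ran."""
    calls = []

    @disk_cached(cache_dir)
    def train(algorithm: str, seed: int) -> dict:
        calls.append(algorithm)  # records real executions only
        return {"algorithm": algorithm, "roc_auc": 0.9}

    train(algorithm="xgboost", seed=0)
    train(algorithm="xgboost", seed=0)  # simulated restart: served from disk
    return len(calls)


if __name__ == "__main__":
    print(count_executions(tempfile.mkdtemp()))
```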

**Observability.** You can see every LLM prompt, every generated code snippet, every tool invocation, and every result in the Flyte UI. When the agent makes a questionable decision (like skipping feature engineering on temporal data), you can trace exactly why: the prompt it received, the profile it read, the reasoning it generated.

**Compute isolation.** The ML tools run on cloud instances with the CPU and memory they need. The agent itself runs on a small 1-CPU instance since all it does is call the LLM and dispatch tool tasks. You're not bottlenecked by your laptop, and you're not paying for GPU-class compute to run an orchestration loop.

**Parallelism.** Multiple experiments run simultaneously via `asyncio.gather()`, each dispatching its own durable tasks. Flyte handles the scheduling. If you have three experiments in a batch and each involves training + evaluation, that's six tasks running concurrently on cloud compute.
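In `run_agent`, the fan-out is an `asyncio.gather` over `_run_experiment` calls followed by a filter that drops `None` results. The self-contained sketch below has the same shape, with a hypothetical `run_one` coroutine in place of `_run_experiment`, which dispatches Flyte tasks in the real code. `gather` runs the coroutines concurrently and returns results in submission order, so the report loop sees experiments in launch order rather than completion order. By default, `gather` propagates the first exception raised by any coroutine. Returning `None` on failure, as `_run_experiment` does, keeps one bad experiment from failing the whole batch.

```python
import asyncio
import random
from typing import Optional


async def run_one(name: str) -> Optional[dict]:
    """Hypothetical stand-in for `_run_experiment`: None means it failed every retry."""
    await asyncio.sleep(random.uniform(0.01, 0.05))  # simulated, variable run time
    if name == "broken":
        return None
    return {"name": name, "roc_auc": 0.8}


async def run_batch(names: list[str]) -> list[dict]:
    # All coroutines run concurrently; results come back in submission order,
    # not completion order.
    results = await asyncio.gather(*(run_one(n) for n in names))
    # Keep only successful experiments, like the `if exp_result is not None` check.
    return [r for r in results if r is not None]


if __name__ == "__main__":
    print(asyncio.run(run_batch(["baseline", "broken", "tuned"])))
```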

The MLE Bot is a specific example of a more general pattern: giving an LLM the ability to reason about *what* work should be done, while Flyte handles *how* that work gets executed safely, durably, and at scale. The sandbox is the boundary between the two. Everything above the boundary is LLM-generated and untrusted. Everything below it is your code, running on your infrastructure, with all the guarantees you'd expect from a production orchestrator.

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/agents/compliance-monitoring-agent ===

# Compliance monitoring agent

> [!NOTE]
> Code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/compliance_monitoring_agent).

This example demonstrates how to build a regulatory and compliance monitoring agent on Flyte. The agent watches trusted regulatory sources — FDA guidance, SEC filings, sanctions lists, state-level privacy laws — and routes structured, **citation-precise** findings to the right downstream team (compliance, legal, or clinical ops).

Compliance monitoring requires **citation precision and recency** so that every finding can be verified. The [You.com Research API](https://you.com/docs/research/overview) returns a grounded, synthesized answer plus structured sources (URL, title, snippet). The agent uses `source_control` to restrict research to trusted government and regulator domains within a recency window, and `output_schema` to get machine-readable findings. [Claude](https://docs.anthropic.com/) via [LiteLLM](https://docs.litellm.ai/) triages each finding for severity and routing. Combined with Flyte's audit lineage, this gives you end-to-end traceability from query to citation.
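Because every finding is meant to be verifiable, you might also want a cheap client-side check that each finding's `source_url` actually points at one of the watch item's trusted domains before it is routed. The helper below is a hypothetical addition, not part of the tutorial code. It matches hosts more strictly than the tutorial's `_domain` string replacement and uses only `urllib.parse`.

```python
from urllib.parse import urlparse


def is_trusted(url: str, trusted_domains: list[str]) -> bool:
    """Return True if `url`'s host is one of `trusted_domains` or a subdomain of one."""
    host = (urlparse(url).hostname or "").lower()  # hostname drops port and userinfo
    for domain in trusted_domains:
        domain = domain.lower().lstrip(".")
        # Match the exact domain or a dot-separated subdomain, so "notfda.gov"
        # and "fda.gov.example.com" do not pass for "fda.gov".
        if host == domain or host.endswith("." + domain):
            return True
    return False


if __name__ == "__main__":
    trusted = ["fda.gov", "sec.gov"]
    for url in ["https://www.fda.gov/news", "https://fda.gov.example.com/x", "https://notfda.gov"]:
        print(url, is_trusted(url, trusted))
```

Inside `monitor_watch_item`, for example, findings for which `is_trusted(url, item.trusted_domains)` is false could be dropped or flagged before triage.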

Flyte provides:

- **Fan-out parallelism** across watch items
- **`@flyte.trace`** on every You.com Research and LLM call
- **Retries** on monitoring tasks for robustness
- **Flyte reports** grouped by team and severity
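Grouping the report by team and severity comes down to a small step before rendering. Below is a sketch over plain dicts (for example, produced with `dataclasses.asdict` from the tutorial's `Finding`). It uses the same action, watch, info urgency ordering as the `_SEVERITY_ORDER` mapping in the report code. `group_by_team` is a hypothetical helper, and the sample findings are made up.

```python
from collections import defaultdict

# Same urgency ordering as the tutorial's _SEVERITY_ORDER; unknown values sort last.
SEVERITY_ORDER = {"action": 0, "watch": 1, "info": 2}


def group_by_team(findings: list[dict]) -> dict[str, list[dict]]:
    """Group findings by team (teams in alphabetical order), most severe first."""
    groups: dict[str, list[dict]] = defaultdict(list)
    for f in findings:
        groups[f.get("team", "unassigned")].append(f)
    return {
        team: sorted(
            items,
            key=lambda f: SEVERITY_ORDER.get(f.get("severity", "info"), len(SEVERITY_ORDER)),
        )
        for team, items in sorted(groups.items())
    }


if __name__ == "__main__":
    # Made-up sample findings for illustration only.
    findings = [
        {"team": "legal", "severity": "info", "title": "Guidance FAQ updated"},
        {"team": "compliance", "severity": "watch", "title": "Proposed rule open for comment"},
        {"team": "legal", "severity": "action", "title": "Final rule takes effect"},
    ]
    for team, items in group_by_team(findings).items():
        print(team, [f["title"] for f in items])
```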

![Compliance monitoring agent report](https://www.union.ai/docs/v2/union/_static/images/tutorials/compliance_monitoring_agent/compliance-monitoring-agent.png)

## Setting up the environment

The agent runs in a `TaskEnvironment` with secrets for the You.com and Anthropic API keys and a container image built from the `uv` script dependencies.

```python
# /// script
# requires-python = "==3.13"
# dependencies = [
#     "flyte>=2.4.0",
#     "httpx>=0.27.0",
#     "litellm>=1.72.0",
# ]
# main = "compliance_monitoring"
# params = ""
# ///
"""Regulatory & compliance monitoring agent.

Watches trusted regulatory sources via the You.com Research API (with
domain/freshness source controls and a structured output schema), then uses
Claude to assign severity and route citation-precise findings to the right team.
Every external call is traced so Flyte's audit lineage extends to the web layer.
"""

# {{docs-fragment env}}
import asyncio
import json
import os
from dataclasses import dataclass, field

import flyte

MODEL = "anthropic/claude-haiku-4-5"

env = flyte.TaskEnvironment(
    name="compliance-monitoring",
    secrets=[
        flyte.Secret(key="youdotcom-api-key", as_env_var="YOU_API_KEY"),
        flyte.Secret(key="internal-anthropic-api-key", as_env_var="ANTHROPIC_API_KEY"),
    ],
    image=flyte.Image.from_uv_script(__file__, name="compliance-monitoring", pre=True),
    resources=flyte.Resources(cpu="1", memory="1Gi"),
)
# {{/docs-fragment env}}

# {{docs-fragment data_types}}
@dataclass
class WatchItem:
    topic: str
    trusted_domains: list[str]
    team: str

@dataclass
class Finding:
    topic: str
    team: str
    title: str
    summary: str
    source_url: str
    published_date: str
    snippet: str
    domain: str = ""
    favicon: str = ""
    severity: str = "info"
    rationale: str = ""

def _domain(url: str) -> str:
    from urllib.parse import urlparse

    try:
        return urlparse(url).netloc.replace("www.", "")
    except Exception:
        return ""

def _favicon_for(url: str) -> str:
    return f"https://ydc-index.io/favicon?domain={_domain(url)}&size=128"

@dataclass
class ComplianceReport:
    findings: list[Finding] = field(default_factory=list)
# {{/docs-fragment data_types}}

# {{docs-fragment you_research}}
YOU_RESEARCH_URL = "https://api.you.com/v1/research"

FINDINGS_SCHEMA = {
    "type": "object",
    "properties": {
        "findings": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "title": {"type": "string"},
                    "summary": {"type": "string"},
                    "source_url": {"type": "string"},
                    "published_date": {"type": "string"},
                    "snippet": {"type": "string"},
                },
                "required": [
                    "title",
                    "summary",
                    "source_url",
                    "published_date",
                    "snippet",
                ],
                "additionalProperties": False,
            },
        }
    },
    "required": ["findings"],
    "additionalProperties": False,
}

async def _you_post(url: str, body: dict, timeout: float = 300.0) -> dict:
    """POST with exponential backoff + jitter on 429 rate limits."""
    import asyncio
    import random

    import httpx

    headers = {
        "X-API-Key": os.environ["YOU_API_KEY"],
        "Content-Type": "application/json",
    }
    async with httpx.AsyncClient(timeout=timeout) as client:
        for attempt in range(7):
            resp = await client.post(url, headers=headers, json=body)
            if resp.status_code == 429 and attempt < 6:
                wait = float(resp.headers.get("retry-after") or 0) or min(2**attempt, 30)
                await asyncio.sleep(wait + random.uniform(0, 2))
                continue
            resp.raise_for_status()
            return resp.json()
    resp.raise_for_status()
    return resp.json()

@flyte.trace
async def you_research(
    question: str,
    include_domains: list[str],
    freshness: str,
    research_effort: str = "standard",
) -> dict:
    """Call the You.com Research API with domain + freshness source controls."""
    body = {
        "input": question,
        "research_effort": research_effort,
        "source_control": {
            "include_domains": include_domains,
            "freshness": freshness,
        },
        "output_schema": FINDINGS_SCHEMA,
    }
    return await _you_post(YOU_RESEARCH_URL, body)
# {{/docs-fragment you_research}}

# {{docs-fragment llm}}
@flyte.trace
async def triage(topic: str, findings: list[dict]) -> list[dict]:
    """Use Claude to assign a severity + rationale to each finding."""
    from litellm import acompletion

    if not findings:
        return []

    system = (
        "You are a regulatory-compliance triage analyst. For each finding, "
        "assign a severity of 'info' (FYI), 'watch' (monitor closely), or "
        "'action' (requires a concrete compliance/legal response), and a one-"
        "sentence rationale. Respond ONLY with JSON: "
        '{"triage": [{"severity": str, "rationale": str}]} with one entry per '
        "finding, in order."
    )
    listing = "\n".join(
        f"[{i + 1}] {f.get('title', '')}: {f.get('summary', '')}"
        for i, f in enumerate(findings)
    )
    resp = await acompletion(
        model=MODEL,
        messages=[
            {"role": "system", "content": system},
            {"role": "user", "content": f"Topic: {topic}\n\nFindings:\n{listing}"},
        ],
        temperature=0.0,
        max_tokens=1024,
    )
    parsed = _parse_json(resp.choices[0].message.content)
    return parsed.get("triage", []) if isinstance(parsed, dict) else []

def _parse_json(text: str) -> dict | list:
    text = text.strip()
    if text.startswith("```"):
        text = text.split("```", 2)[1]
        if text.lstrip().startswith("json"):
            text = text.lstrip()[4:]
    start = min((i for i in (text.find("{"), text.find("[")) if i != -1), default=0)
    end = max(text.rfind("}"), text.rfind("]")) + 1
    return json.loads(text[start:end])
# {{/docs-fragment llm}}

# {{docs-fragment monitor_watch_item}}
@env.task(retries=3)
async def monitor_watch_item(item: WatchItem, freshness: str) -> list[Finding]:
    """Research one regulatory topic and produce triaged, cited findings."""
    question = (
        f"What are the most recent changes, updates, or new guidance regarding "
        f"'{item.topic}'? Report concrete, dated changes with their sources."
    )
    result = await you_research(question, item.trusted_domains, freshness)
    output = result.get("output", {})

    # Build a lookup from the Research API's full source list (url -> metadata).
    src_by_url: dict[str, dict] = {}
    for s in output.get("sources", []) or []:
        url = str(s.get("url", ""))
        if url:
            src_by_url[url] = s

    content = output.get("content", {})
    if isinstance(content, str):
        content = _parse_json(content) if content.strip() else {}
    raw_findings = content.get("findings", []) if isinstance(content, dict) else []

    triage_results = await triage(item.topic, raw_findings)

    findings: list[Finding] = []
    for i, f in enumerate(raw_findings):
        t = triage_results[i] if i < len(triage_results) else {}
        url = str(f.get("source_url", ""))
        meta = src_by_url.get(url, {})
        snippet = str(f.get("snippet", "")) or str((meta.get("snippets") or [""])[0])
        findings.append(
            Finding(
                topic=item.topic,
                team=item.team,
                title=str(f.get("title", "") or meta.get("title", "")),
                summary=str(f.get("summary", "")),
                source_url=url,
                published_date=str(f.get("published_date", "")),
                snippet=snippet,
                domain=_domain(url),
                favicon=_favicon_for(url),
                severity=str(t.get("severity", "info")),
                rationale=str(t.get("rationale", "")),
            )
        )
    return findings
# {{/docs-fragment monitor_watch_item}}

# {{docs-fragment report}}
_SEVERITY_ORDER = {"action": 0, "watch": 1, "info": 2}
_SEVERITY_STYLE = {
    "action": ("#fdecea", "#c0392b"),
    "watch": ("#fdf3e1", "#b7791f"),
    "info": ("#e3f1fb", "#2b6cb0"),
}

REPORT_CSS = """
<style>
  .rpt { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto,
         Helvetica, Arial, sans-serif; color:#1f2933; max-width:1040px;
         margin:0 auto; }
  .rpt h1 { font-size:22px; margin:0 0 4px; color:#102a43; }
  .rpt .sub { color:#647488; font-size:13px; margin:0 0 18px; }
  .rpt .stats { display:flex; gap:10px; flex-wrap:wrap; margin:0 0 22px; }
  .rpt .pill { background:#f0f4f8; border-radius:999px; padding:6px 14px;
               font-size:13px; color:#334e68; }
  .rpt .pill b { color:#102a43; }
  .rpt .card { border:1px solid #e4e7eb; border-radius:12px; padding:16px 18px;
               margin:0 0 14px; box-shadow:0 1px 3px rgba(16,42,67,0.06);
               background:#fff; border-left:4px solid #cbd2d9; }
  .rpt .card.action { border-left-color:#c0392b; }
  .rpt .card.watch { border-left-color:#b7791f; }
  .rpt .card.info { border-left-color:#2b6cb0; }
  .rpt .card h2 { font-size:15px; margin:0 0 6px; color:#102a43; }
  .rpt .sev { display:inline-block; font-size:11px; font-weight:700;
              padding:3px 9px; border-radius:6px; text-transform:uppercase;
              letter-spacing:.03em; margin-right:8px; }
  .rpt .team { display:inline-block; font-size:11px; font-weight:600;
               padding:3px 9px; border-radius:6px; background:#edf0f3;
               color:#52606d; text-transform:uppercase; }
  .rpt .summary { margin:8px 0; font-size:14px; line-height:1.45; }
  .rpt .rationale { font-size:13px; color:#486581; font-style:italic; }
  .rpt .meta { color:#829ab1; font-size:12px; }
  .rpt a { color:#2b6cb0; text-decoration:none; }
  .rpt a:hover { text-decoration:underline; }
  .rpt .empty { color:#829ab1; font-style:italic; padding:8px 0; }
  .rpt .cite { display:flex; gap:9px; align-items:flex-start; background:#f7f9fb;
               border:1px solid #eef1f4; border-radius:8px; padding:8px 10px;
               margin-top:10px; }
  .rpt .cite img.fav { width:16px; height:16px; border-radius:3px; margin-top:2px;
                       flex:0 0 auto; background:#e4e7eb; }
  .rpt .cite .cb { font-size:12px; line-height:1.45; }
  .rpt .cite .cdom { font-weight:600; color:#334e68; }
  .rpt .cite .ctag { font-size:10px; font-weight:700; text-transform:uppercase;
                     color:#fff; background:#5b8def; border-radius:4px;
                     padding:1px 5px; margin-left:6px; }
  .rpt .cite .cmeta { color:#829ab1; }
  .rpt .cite .csnip { color:#52606d; font-style:italic; margin-top:3px; }
  .rpt .yoube { font-size:11px; color:#9aa5b1; margin-top:4px; }
</style>
"""

def _sev_badge(sev: str) -> str:
    bg, fg = _SEVERITY_STYLE.get(sev, ("#edf0f3", "#52606d"))
    return f"<span class='sev' style='background:{bg};color:{fg}'>{sev}</span>"

def _cite(f: Finding) -> str:
    """Render a rich You.com Research citation with domain, date, and snippet."""
    if not f.source_url:
        return ""
    meta = f.published_date[:10] if f.published_date else ""
    snip = f"<div class='csnip'>&ldquo;{f.snippet}&rdquo;</div>" if f.snippet else ""
    return (
        f"<div class='cite'><img class='fav' src='{f.favicon}' alt=''/>"
        f"<div class='cb'>"
        f"<a href='{f.source_url}'><span class='cdom'>{f.domain or 'source'}</span></a>"
        f"<span class='ctag'>research</span>"
        f"<div class='cmeta'>{meta} &middot; {f.title}</div>{snip}</div></div>"
    )

def _render_report(report: ComplianceReport) -> str:
    findings = sorted(
        report.findings,
        key=lambda f: (_SEVERITY_ORDER.get(f.severity, 3), f.team),
    )
    counts = {s: sum(1 for f in findings if f.severity == s) for s in _SEVERITY_ORDER}
    cited = sum(1 for f in findings if f.source_url)

    cards = []
    for f in findings:
        cards.append(
            f"<div class='card {f.severity}'>"
            f"<div>{_sev_badge(f.severity)}<span class='team'>{f.team}</span></div>"
            f"<h2>{f.title or f.topic}</h2>"
            f"<div class='summary'>{f.summary}</div>"
            f"<div class='rationale'>{f.rationale}</div>"
            f"<div class='meta' style='margin-top:6px'>{f.topic}</div>"
            f"{_cite(f)}</div>"
        )

    return f"""
    {REPORT_CSS}
    <div class="rpt">
      <h1>Compliance Monitoring Findings</h1>
      <p class="sub">Citation-precise regulatory changes from trusted domains —
      every finding links to a You.com Research source with snippet provenance.</p>
      <div class="stats">
        <span class="pill"><b>{len(findings)}</b> findings</span>
        <span class="pill"><b>{cited}</b> cited You.com sources</span>
        <span class="pill" style="background:#fdecea;color:#c0392b">
          <b>{counts['action']}</b> action</span>
        <span class="pill" style="background:#fdf3e1;color:#b7791f">
          <b>{counts['watch']}</b> watch</span>
        <span class="pill" style="background:#e3f1fb;color:#2b6cb0">
          <b>{counts['info']}</b> info</span>
      </div>
      {''.join(cards) or "<p class='empty'>No findings in this window.</p>"}
      <p class="yoube">Findings retrieved via the You.com Research API with
      <code>source_control</code> domain allowlists and freshness filters.
      Flyte logs which agent called which query and got which document — full
      prompt &rarr; citation lineage for audit.</p>
    </div>
    """
# {{/docs-fragment report}}

# {{docs-fragment driver}}
def _default_watch_items() -> list[WatchItem]:
    return [
        WatchItem(
            topic="FDA guidance on AI/ML-enabled medical device software",
            trusted_domains=["fda.gov", "federalregister.gov"],
            team="clinical",
        ),
        WatchItem(
            topic="SEC climate-related disclosure rules for public companies",
            trusted_domains=["sec.gov", "federalregister.gov"],
            team="legal",
        ),
        WatchItem(
            topic="OFAC sanctions list additions and updates",
            trusted_domains=["treasury.gov", "ofac.treasury.gov"],
            team="compliance",
        ),
        WatchItem(
            topic="State-level consumer data privacy laws and amendments",
            trusted_domains=["iapp.org", "oag.ca.gov"],
            team="legal",
        ),
        WatchItem(
            topic="FDA drug recalls and safety communications",
            trusted_domains=["fda.gov"],
            team="clinical",
        ),
        WatchItem(
            topic="HIPAA enforcement actions and guidance updates",
            trusted_domains=["hhs.gov"],
            team="compliance",
        ),
    ]

@env.task(report=True)
async def compliance_monitoring(
    watch_items: list[WatchItem] | None = None,
    freshness: str = "month",
) -> ComplianceReport:
    """Fan out across regulatory watch items and aggregate triaged findings."""
    if watch_items is None:
        watch_items = _default_watch_items()

    with flyte.group("monitor-watch-items"):
        results = await asyncio.gather(
            *[monitor_watch_item(item, freshness) for item in watch_items]
        )

    report = ComplianceReport(findings=[f for fs in results for f in fs])

    await flyte.report.replace.aio(_render_report(report), do_flush=True)
    await flyte.report.flush.aio()
    return report
# {{/docs-fragment driver}}

# {{docs-fragment main}}
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(compliance_monitoring)
    print(run.url)
    run.wait()
# {{/docs-fragment main}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/compliance_monitoring_agent/main.py*

The Python version and package dependencies are declared at the top of the file as `uv` inline script metadata:

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#     "flyte>=2.4.0",
#     "httpx>=0.27.0",
#     "litellm>=1.72.0",
# ]
# ///
```

## Data types

Each `WatchItem` specifies a regulatory topic, the trusted domains passed to the Research API's `source_control` allowlist, and the team its findings are routed to. Each `Finding` carries citation metadata (source URL, published date, and snippet) so that every claim can be verified.

```
# {{docs-fragment data_types}}
@dataclass
class WatchItem:
    topic: str
    trusted_domains: list[str]
    team: str

@dataclass
class Finding:
    topic: str
    team: str
    title: str
    summary: str
    source_url: str
    published_date: str
    snippet: str
    domain: str = ""
    favicon: str = ""
    severity: str = "info"
    rationale: str = ""

def _domain(url: str) -> str:
    from urllib.parse import urlparse

    try:
        return urlparse(url).netloc.replace("www.", "")
    except Exception:
        return ""

def _favicon_for(url: str) -> str:
    return f"https://ydc-index.io/favicon?domain={_domain(url)}&size=128"

@dataclass
class ComplianceReport:
    findings: list[Finding] = field(default_factory=list)
# {{/docs-fragment data_types}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/compliance_monitoring_agent/main.py*
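To add your own topic, construct another `WatchItem`. Every resulting `Finding` then carries its citation through to serialization. The sketch below runs without Flyte: it redefines trimmed local copies of the two dataclasses (the `favicon` and `rationale` fields are omitted), and the watch item, URL, and helper name `domain_of` are hypothetical. It strips a leading `www.` with `str.removeprefix`. The tutorial's `_domain` uses `str.replace("www.", "")`, which would also remove a `www.` that appears later in the host name.

```python
import json
from dataclasses import asdict, dataclass
from urllib.parse import urlparse


# Trimmed local copies of the tutorial's dataclasses, so this runs without flyte.
@dataclass
class WatchItem:
    topic: str
    trusted_domains: list[str]
    team: str


@dataclass
class Finding:
    topic: str
    team: str
    title: str
    summary: str
    source_url: str
    published_date: str
    snippet: str
    domain: str = ""
    severity: str = "info"


def domain_of(url: str) -> str:
    # Remove only a leading "www." from the host part of the URL.
    return urlparse(url).netloc.removeprefix("www.")


# Hypothetical watch item routed to the legal team.
item = WatchItem(
    topic="EU AI Act implementation guidance",
    trusted_domains=["eur-lex.europa.eu"],
    team="legal",
)

url = "https://www.example.gov/guidance/2025-01"  # hypothetical source URL
finding = Finding(
    topic=item.topic,
    team=item.team,
    title="Example guidance",
    summary="Hypothetical summary of a dated regulatory change.",
    source_url=url,
    published_date="2025-01-15",
    snippet="Hypothetical quoted snippet from the source page.",
    domain=domain_of(url),
)

# Serializing keeps the citation (URL, date, snippet) attached to the claim.
print(json.dumps(asdict(finding), indent=2))
```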

## Research with the You.com Research API

The `you_research` helper calls the [You.com Research API](https://you.com/docs/research/overview) at `https://api.you.com/v1/research`. It passes `source_control` with an `include_domains` allowlist and a `freshness` filter, and requests structured output via `output_schema`.

See the [Research API reference](https://you.com/docs/api-reference/research/v1-research) for the available `research_effort` levels (`lite`, `standard`, `deep`, `exhaustive`) and for the full `source_control` and `output_schema` parameters. The tutorial's helper defaults to the `standard` effort level.
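The request body that `you_research` POSTs can be built as a plain dictionary, which makes it easy to inspect or unit-test without network access. The sketch below mirrors the tutorial's body layout. `build_research_body` is a hypothetical name, and its client-side check of `research_effort` against the four documented levels is an addition that the tutorial's helper does not perform. In the tutorial, the body is sent to `https://api.you.com/v1/research` with an `X-API-Key` header by `_you_post`. This sketch only assembles the body and does not call the API.

```python
import json

# Effort levels named in the You.com Research API reference.
RESEARCH_EFFORTS = ("lite", "standard", "deep", "exhaustive")


def build_research_body(
    question: str,
    include_domains: list[str],
    freshness: str,
    output_schema: dict,
    research_effort: str = "standard",
) -> dict:
    """Assemble a Research API request body shaped like the tutorial's."""
    if research_effort not in RESEARCH_EFFORTS:
        raise ValueError(f"unknown research_effort: {research_effort!r}")
    return {
        "input": question,
        "research_effort": research_effort,
        "source_control": {
            # Domain allowlist for sources.
            "include_domains": list(include_domains),
            # Freshness filter; the tutorial's default is "month".
            "freshness": freshness,
        },
        # JSON Schema for the structured answer.
        "output_schema": output_schema,
    }


# Trimmed schema for illustration; the tutorial's FINDINGS_SCHEMA defines more fields.
schema = {
    "type": "object",
    "properties": {"findings": {"type": "array", "items": {"type": "object"}}},
    "required": ["findings"],
}

body = build_research_body(
    "What are the most recent HIPAA enforcement actions and guidance updates?",
    ["hhs.gov"],
    "month",
    schema,
)
print(json.dumps(body, indent=2))
```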

```
# {{docs-fragment you_research}}
YOU_RESEARCH_URL = "https://api.you.com/v1/research"

FINDINGS_SCHEMA = {
    "type": "object",
    "properties": {
        "findings": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "title": {"type": "string"},
                    "summary": {"type": "string"},
                    "source_url": {"type": "string"},
                    "published_date": {"type": "string"},
                    "snippet": {"type": "string"},
                },
                "required": [
                    "title",
                    "summary",
                    "source_url",
                    "published_date",
                    "snippet",
                ],
                "additionalProperties": False,
            },
        }
    },
    "required": ["findings"],
    "additionalProperties": False,
}

async def _you_post(url: str, body: dict, timeout: float = 300.0) -> dict:
    """POST with exponential backoff + jitter on 429 rate limits."""
    import asyncio
    import random

    import httpx

    headers = {
        "X-API-Key": os.environ["YOU_API_KEY"],
        "Content-Type": "application/json",
    }
    async with httpx.AsyncClient(timeout=timeout) as client:
        for attempt in range(7):
            resp = await client.post(url, headers=headers, json=body)
            if resp.status_code == 429 and attempt < 6:
                wait = float(resp.headers.get("retry-after") or 0) or min(2**attempt, 30)
                await asyncio.sleep(wait + random.uniform(0, 2))
                continue
            resp.raise_for_status()
            return resp.json()
    resp.raise_for_status()
    return resp.json()

@flyte.trace
async def you_research(
    question: str,
    include_domains: list[str],
    freshness: str,
    research_effort: str = "standard",
) -> dict:
    """Call the You.com Research API with domain + freshness source controls."""
    body = {
        "input": question,
        "research_effort": research_effort,
        "source_control": {
            "include_domains": include_domains,
            "freshness": freshness,
        },
        "output_schema": FINDINGS_SCHEMA,
    }
    return await _you_post(YOU_RESEARCH_URL, body)
# {{/docs-fragment you_research}}

# {{docs-fragment llm}}
@flyte.trace
async def triage(topic: str, findings: list[dict]) -> list[dict]:
    """Use Claude to assign a severity + rationale to each finding."""
    from litellm import acompletion

    if not findings:
        return []

    system = (
        "You are a regulatory-compliance triage analyst. For each finding, "
        "assign a severity of 'info' (FYI), 'watch' (monitor closely), or "
        "'action' (requires a concrete compliance/legal response), and a one-"
        "sentence rationale. Respond ONLY with JSON: "
        '{"triage": [{"severity": str, "rationale": str}]} with one entry per '
        "finding, in order."
    )
    listing = "\n".join(
        f"[{i + 1}] {f.get('title', '')}: {f.get('summary', '')}"
        for i, f in enumerate(findings)
    )
    resp = await acompletion(
        model=MODEL,
        messages=[
            {"role": "system", "content": system},
            {"role": "user", "content": f"Topic: {topic}\n\nFindings:\n{listing}"},
        ],
        temperature=0.0,
        max_tokens=1024,
    )
    parsed = _parse_json(resp.choices[0].message.content)
    return parsed.get("triage", []) if isinstance(parsed, dict) else []

def _parse_json(text: str) -> dict | list:
    text = text.strip()
    if text.startswith("```"):
        text = text.split("```", 2)[1]
        if text.lstrip().startswith("json"):
            text = text.lstrip()[4:]
    start = min((i for i in (text.find("{"), text.find("[")) if i != -1), default=0)
    end = max(text.rfind("}"), text.rfind("]")) + 1
    return json.loads(text[start:end])
# {{/docs-fragment llm}}

# {{docs-fragment monitor_watch_item}}
@env.task(retries=3)
async def monitor_watch_item(item: WatchItem, freshness: str) -> list[Finding]:
    """Research one regulatory topic and produce triaged, cited findings."""
    question = (
        f"What are the most recent changes, updates, or new guidance regarding "
        f"'{item.topic}'? Report concrete, dated changes with their sources."
    )
    result = await you_research(question, item.trusted_domains, freshness)
    output = result.get("output", {})

    # Build a lookup from the Research API's full source list (url -> metadata).
    src_by_url: dict[str, dict] = {}
    for s in output.get("sources", []) or []:
        url = str(s.get("url", ""))
        if url:
            src_by_url[url] = s

    content = output.get("content", {})
    if isinstance(content, str):
        content = _parse_json(content) if content.strip() else {}
    raw_findings = content.get("findings", []) if isinstance(content, dict) else []

    triage_results = await triage(item.topic, raw_findings)

    findings: list[Finding] = []
    for i, f in enumerate(raw_findings):
        t = triage_results[i] if i < len(triage_results) else {}
        url = str(f.get("source_url", ""))
        meta = src_by_url.get(url, {})
        snippet = str(f.get("snippet", "")) or str((meta.get("snippets") or [""])[0])
        findings.append(
            Finding(
                topic=item.topic,
                team=item.team,
                title=str(f.get("title", "") or meta.get("title", "")),
                summary=str(f.get("summary", "")),
                source_url=url,
                published_date=str(f.get("published_date", "")),
                snippet=snippet,
                domain=_domain(url),
                favicon=_favicon_for(url),
                severity=str(t.get("severity", "info")),
                rationale=str(t.get("rationale", "")),
            )
        )
    return findings
# {{/docs-fragment monitor_watch_item}}

# {{docs-fragment report}}
_SEVERITY_ORDER = {"action": 0, "watch": 1, "info": 2}
_SEVERITY_STYLE = {
    "action": ("#fdecea", "#c0392b"),
    "watch": ("#fdf3e1", "#b7791f"),
    "info": ("#e3f1fb", "#2b6cb0"),
}

REPORT_CSS = """
<style>
  .rpt { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto,
         Helvetica, Arial, sans-serif; color:#1f2933; max-width:1040px;
         margin:0 auto; }
  .rpt h1 { font-size:22px; margin:0 0 4px; color:#102a43; }
  .rpt .sub { color:#647488; font-size:13px; margin:0 0 18px; }
  .rpt .stats { display:flex; gap:10px; flex-wrap:wrap; margin:0 0 22px; }
  .rpt .pill { background:#f0f4f8; border-radius:999px; padding:6px 14px;
               font-size:13px; color:#334e68; }
  .rpt .pill b { color:#102a43; }
  .rpt .card { border:1px solid #e4e7eb; border-radius:12px; padding:16px 18px;
               margin:0 0 14px; box-shadow:0 1px 3px rgba(16,42,67,0.06);
               background:#fff; border-left:4px solid #cbd2d9; }
  .rpt .card.action { border-left-color:#c0392b; }
  .rpt .card.watch { border-left-color:#b7791f; }
  .rpt .card.info { border-left-color:#2b6cb0; }
  .rpt .card h2 { font-size:15px; margin:0 0 6px; color:#102a43; }
  .rpt .sev { display:inline-block; font-size:11px; font-weight:700;
              padding:3px 9px; border-radius:6px; text-transform:uppercase;
              letter-spacing:.03em; margin-right:8px; }
  .rpt .team { display:inline-block; font-size:11px; font-weight:600;
               padding:3px 9px; border-radius:6px; background:#edf0f3;
               color:#52606d; text-transform:uppercase; }
  .rpt .summary { margin:8px 0; font-size:14px; line-height:1.45; }
  .rpt .rationale { font-size:13px; color:#486581; font-style:italic; }
  .rpt .meta { color:#829ab1; font-size:12px; }
  .rpt a { color:#2b6cb0; text-decoration:none; }
  .rpt a:hover { text-decoration:underline; }
  .rpt .empty { color:#829ab1; font-style:italic; padding:8px 0; }
  .rpt .cite { display:flex; gap:9px; align-items:flex-start; background:#f7f9fb;
               border:1px solid #eef1f4; border-radius:8px; padding:8px 10px;
               margin-top:10px; }
  .rpt .cite img.fav { width:16px; height:16px; border-radius:3px; margin-top:2px;
                       flex:0 0 auto; background:#e4e7eb; }
  .rpt .cite .cb { font-size:12px; line-height:1.45; }
  .rpt .cite .cdom { font-weight:600; color:#334e68; }
  .rpt .cite .ctag { font-size:10px; font-weight:700; text-transform:uppercase;
                     color:#fff; background:#5b8def; border-radius:4px;
                     padding:1px 5px; margin-left:6px; }
  .rpt .cite .cmeta { color:#829ab1; }
  .rpt .cite .csnip { color:#52606d; font-style:italic; margin-top:3px; }
  .rpt .yoube { font-size:11px; color:#9aa5b1; margin-top:4px; }
</style>
"""

from html import escape  # titles, summaries, and snippets are web/LLM text

def _sev_badge(sev: str) -> str:
    bg, fg = _SEVERITY_STYLE.get(sev, ("#edf0f3", "#52606d"))
    return f"<span class='sev' style='background:{bg};color:{fg}'>{escape(sev)}</span>"

def _cite(f: Finding) -> str:
    """Render a rich You.com Research citation with domain, date, and snippet."""
    if not f.source_url:
        return ""
    meta = escape(f.published_date[:10]) if f.published_date else ""
    snip = f"<div class='csnip'>&ldquo;{escape(f.snippet)}&rdquo;</div>" if f.snippet else ""
    return (
        f"<div class='cite'><img class='fav' src='{escape(f.favicon)}' alt=''/>"
        f"<div class='cb'>"
        f"<a href='{escape(f.source_url)}'><span class='cdom'>{escape(f.domain or 'source')}</span></a>"
        f"<span class='ctag'>research</span>"
        f"<div class='cmeta'>{meta} &middot; {escape(f.title)}</div>{snip}</div></div>"
    )

def _render_report(report: ComplianceReport) -> str:
    findings = sorted(
        report.findings,
        key=lambda f: (_SEVERITY_ORDER.get(f.severity, 3), f.team),
    )
    counts = {s: sum(1 for f in findings if f.severity == s) for s in _SEVERITY_ORDER}
    cited = sum(1 for f in findings if f.source_url)

    cards = []
    for f in findings:
        # Escape every model- or web-supplied field before embedding it in HTML.
        cards.append(
            f"<div class='card {escape(f.severity)}'>"
            f"<div>{_sev_badge(f.severity)}<span class='team'>{escape(f.team)}</span></div>"
            f"<h2>{escape(f.title or f.topic)}</h2>"
            f"<div class='summary'>{escape(f.summary)}</div>"
            f"<div class='rationale'>{escape(f.rationale)}</div>"
            f"<div class='meta' style='margin-top:6px'>{escape(f.topic)}</div>"
            f"{_cite(f)}</div>"
        )

    return f"""
    {REPORT_CSS}
    <div class="rpt">
      <h1>Compliance Monitoring Findings</h1>
      <p class="sub">Citation-precise regulatory changes from trusted domains —
      every finding links to a You.com Research source with snippet provenance.</p>
      <div class="stats">
        <span class="pill"><b>{len(findings)}</b> findings</span>
        <span class="pill"><b>{cited}</b> cited You.com sources</span>
        <span class="pill" style="background:#fdecea;color:#c0392b">
          <b>{counts['action']}</b> action</span>
        <span class="pill" style="background:#fdf3e1;color:#b7791f">
          <b>{counts['watch']}</b> watch</span>
        <span class="pill" style="background:#e3f1fb;color:#2b6cb0">
          <b>{counts['info']}</b> info</span>
      </div>
      {''.join(cards) or "<p class='empty'>No findings in this window.</p>"}
      <p class="yoube">Findings retrieved via the You.com Research API with
      <code>source_control</code> domain allowlists and freshness filters.
      Flyte logs which agent called which query and got which document — full
      prompt &rarr; citation lineage for audit.</p>
    </div>
    """
# {{/docs-fragment report}}

# {{docs-fragment driver}}
def _default_watch_items() -> list[WatchItem]:
    return [
        WatchItem(
            topic="FDA guidance on AI/ML-enabled medical device software",
            trusted_domains=["fda.gov", "federalregister.gov"],
            team="clinical",
        ),
        WatchItem(
            topic="SEC climate-related disclosure rules for public companies",
            trusted_domains=["sec.gov", "federalregister.gov"],
            team="legal",
        ),
        WatchItem(
            topic="OFAC sanctions list additions and updates",
            trusted_domains=["treasury.gov", "ofac.treasury.gov"],
            team="compliance",
        ),
        WatchItem(
            topic="State-level consumer data privacy laws and amendments",
            trusted_domains=["iapp.org", "oag.ca.gov"],
            team="legal",
        ),
        WatchItem(
            topic="FDA drug recalls and safety communications",
            trusted_domains=["fda.gov"],
            team="clinical",
        ),
        WatchItem(
            topic="HIPAA enforcement actions and guidance updates",
            trusted_domains=["hhs.gov"],
            team="compliance",
        ),
    ]

@env.task(report=True)
async def compliance_monitoring(
    watch_items: list[WatchItem] | None = None,
    freshness: str = "month",
) -> ComplianceReport:
    """Fan out across regulatory watch items and aggregate triaged findings."""
    if watch_items is None:
        watch_items = _default_watch_items()

    with flyte.group("monitor-watch-items"):
        results = await asyncio.gather(
            *[monitor_watch_item(item, freshness) for item in watch_items]
        )

    report = ComplianceReport(findings=[f for fs in results for f in fs])

    await flyte.report.replace.aio(_render_report(report), do_flush=True)
    await flyte.report.flush.aio()
    return report
# {{/docs-fragment driver}}

# {{docs-fragment main}}
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(compliance_monitoring)
    print(run.url)
    run.wait()
# {{/docs-fragment main}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/compliance_monitoring_agent/main.py*

## Triage findings with Claude

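After the Research API returns structured findings, the traced `triage` function sends all of them to Claude in a single LiteLLM call. Claude assigns each finding a severity (`info`, `watch`, or `action`) and a one-sentence rationale, and returns them as JSON with one entry per finding, in order. The team a finding is routed to comes from its watch item's `team` field, not from the model.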
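The prompt in `triage` asks for one JSON entry per finding, in order, so the reply is matched to findings by position: the entry at index `i` belongs to the finding at index `i`. `monitor_watch_item` pairs them by index and falls back to `info` when the reply is short. It does, however, keep whatever label the model returns, and `_render_report` counts only the three known labels. Here is a minimal sketch of a stricter pairing step, written as a hypothetical `align_triage` helper that is not part of the tutorial:

```python
VALID_SEVERITIES = {"info", "watch", "action"}


def align_triage(findings: list[dict], triage: list) -> list[dict]:
    """Attach triage entry i to finding i, normalizing the severity label.

    Missing or malformed entries default to severity "info" with an empty
    rationale, matching the fallback in `monitor_watch_item`. Labels are
    lower-cased, and anything outside VALID_SEVERITIES becomes "info".
    """
    aligned = []
    for i, finding in enumerate(findings):
        entry = triage[i] if i < len(triage) else {}
        if not isinstance(entry, dict):
            entry = {}  # the model returned something other than an object
        severity = str(entry.get("severity", "info")).strip().lower()
        if severity not in VALID_SEVERITIES:
            severity = "info"
        aligned.append(
            {**finding, "severity": severity, "rationale": str(entry.get("rationale", ""))}
        )
    return aligned


findings = [{"title": "A"}, {"title": "B"}, {"title": "C"}]
reply = [{"severity": "ACTION", "rationale": "Deadline set."}, {"severity": "urgent"}]
print([f["severity"] for f in align_triage(findings, reply)])  # ['action', 'info', 'info']
```

Mapping unknown labels to `info` keeps noise down, but a compliance team may prefer to map them to `watch` so that an unusual label is never silently downgraded.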

```
# {{docs-fragment llm}}
@flyte.trace
async def triage(topic: str, findings: list[dict]) -> list[dict]:
    """Use Claude to assign a severity + rationale to each finding."""
    from litellm import acompletion

    if not findings:
        return []

    system = (
        "You are a regulatory-compliance triage analyst. For each finding, "
        "assign a severity of 'info' (FYI), 'watch' (monitor closely), or "
        "'action' (requires a concrete compliance/legal response), and a one-"
        "sentence rationale. Respond ONLY with JSON: "
        '{"triage": [{"severity": str, "rationale": str}]} with one entry per '
        "finding, in order."
    )
    listing = "\n".join(
        f"[{i + 1}] {f.get('title', '')}: {f.get('summary', '')}"
        for i, f in enumerate(findings)
    )
    resp = await acompletion(
        model=MODEL,
        messages=[
            {"role": "system", "content": system},
            {"role": "user", "content": f"Topic: {topic}\n\nFindings:\n{listing}"},
        ],
        temperature=0.0,
        max_tokens=1024,
    )
    parsed = _parse_json(resp.choices[0].message.content)
    return parsed.get("triage", []) if isinstance(parsed, dict) else []

def _parse_json(text: str) -> dict | list:
    text = text.strip()
    if text.startswith("```"):
        text = text.split("```", 2)[1]
        if text.lstrip().startswith("json"):
            text = text.lstrip()[4:]
    start = min((i for i in (text.find("{"), text.find("[")) if i != -1), default=0)
    end = max(text.rfind("}"), text.rfind("]")) + 1
    return json.loads(text[start:end])
# {{/docs-fragment llm}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/compliance_monitoring_agent/main.py*
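`triage` asks Claude to respond only with JSON. `_parse_json` copes with a leading Markdown fence by slicing from the first `{` or `[` to the last `}` or `]`. That slice fails if the model adds prose containing a brace after the JSON. The following is a minimal alternative sketch. It uses the standard library's `json.JSONDecoder.raw_decode` to stop at the end of the first complete JSON value. `extract_first_json` is a hypothetical name, and this is not the tutorial's implementation:

```python
import json


def extract_first_json(text: str):
    """Return the first complete JSON object or array found in `text`.

    Scans for each `{` or `[` in turn and asks `raw_decode` to parse one JSON
    value starting there; anything after that value (a closing fence, trailing
    prose) is ignored. Raises ValueError if no position yields valid JSON.
    """
    decoder = json.JSONDecoder()
    for i, ch in enumerate(text):
        if ch in "{[":
            try:
                value, _ = decoder.raw_decode(text[i:])
                return value
            except json.JSONDecodeError:
                continue  # not valid JSON from this bracket; try the next one
    raise ValueError("no JSON object or array found")


reply = (
    "Here you go:\n```json\n"
    '{"triage": [{"severity": "watch", "rationale": "Comment period open."}]}\n'
    "```\nLet me know if {anything} else is needed."
)
print(extract_first_json(reply)["triage"][0]["severity"])  # watch
```

On this reply, a first-to-last slice would run from the JSON's opening brace to the `}` in `{anything}` and fail with an "Extra data" error. The scan is quadratic in the worst case, which is negligible for a short triage reply.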

## Monitor one watch item

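The `monitor_watch_item` task researches a single regulatory topic and enriches each finding with source metadata from the Research API response. It then attaches Claude's severity and rationale to each finding. Every finding inherits its `team` from the watch item, and the task is declared with `retries=3`.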
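The enrichment step is a join on URL. Each finding's `source_url` is looked up in the Research API's `sources` list. The matching source's title fills in a missing finding title, and its first snippet fills in a missing snippet. Below is a minimal standalone sketch of that join. `enrich` is a hypothetical helper. The sample records are invented and use only the keys the tutorial's code reads (`url`, `title`, `snippets`); they do not reflect the full Research API response shape:

```python
def enrich(raw_findings: list[dict], sources: list[dict]) -> list[dict]:
    """Backfill each finding's title and snippet from the matching source.

    Mirrors the lookup in `monitor_watch_item`: index sources by URL, use the
    source's title when the finding has none, and use the source's first
    snippet when the finding's own snippet is empty.
    """
    by_url = {str(s.get("url", "")): s for s in sources if s.get("url")}
    enriched = []
    for f in raw_findings:
        url = str(f.get("source_url", ""))
        meta = by_url.get(url, {})
        snippets = meta.get("snippets") or [""]
        enriched.append(
            {
                "title": str(f.get("title", "") or meta.get("title", "")),
                "snippet": str(f.get("snippet", "")) or str(snippets[0]),
                "source_url": url,
            }
        )
    return enriched


sources = [
    {"url": "https://www.fda.gov/a", "title": "FDA page A", "snippets": ["First snippet."]},
    {"url": "", "title": "ignored: no URL to join on"},
]
raw = [
    {"title": "", "source_url": "https://www.fda.gov/a", "snippet": ""},
    {"title": "Own title", "source_url": "https://example.org/x", "snippet": "Own snippet."},
]
for row in enrich(raw, sources):
    print(row["title"], "|", row["snippet"])
```

The match is on the exact URL string. A finding whose `source_url` differs from the source entry only by a trailing slash or a `www.` prefix therefore gets no backfill. Normalizing URLs on both sides before building the index would close that gap.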

```
# {{docs-fragment monitor_watch_item}}
@env.task(retries=3)
async def monitor_watch_item(item: WatchItem, freshness: str) -> list[Finding]:
    """Research one regulatory topic and produce triaged, cited findings."""
    question = (
        f"What are the most recent changes, updates, or new guidance regarding "
        f"'{item.topic}'? Report concrete, dated changes with their sources."
    )
    result = await you_research(question, item.trusted_domains, freshness)
    output = result.get("output", {})

    # Build a lookup from the Research API's full source list (url -> metadata).
    src_by_url: dict[str, dict] = {}
    for s in output.get("sources", []) or []:
        url = str(s.get("url", ""))
        if url:
            src_by_url[url] = s

    content = output.get("content", {})
    if isinstance(content, str):
        content = _parse_json(content) if content.strip() else {}
    raw_findings = content.get("findings", []) if isinstance(content, dict) else []

    triage_results = await triage(item.topic, raw_findings)

    findings: list[Finding] = []
    for i, f in enumerate(raw_findings):
        t = triage_results[i] if i < len(triage_results) else {}
        url = str(f.get("source_url", ""))
        meta = src_by_url.get(url, {})
        snippet = str(f.get("snippet", "")) or str((meta.get("snippets") or [""])[0])
        findings.append(
            Finding(
                topic=item.topic,
                team=item.team,
                title=str(f.get("title", "") or meta.get("title", "")),
                summary=str(f.get("summary", "")),
                source_url=url,
                published_date=str(f.get("published_date", "")),
                snippet=snippet,
                domain=_domain(url),
                favicon=_favicon_for(url),
                severity=str(t.get("severity", "info")),
                rationale=str(t.get("rationale", "")),
            )
        )
    return findings
# {{/docs-fragment monitor_watch_item}}

# {{docs-fragment report}}
_SEVERITY_ORDER = {"action": 0, "watch": 1, "info": 2}
_SEVERITY_STYLE = {
    "action": ("#fdecea", "#c0392b"),
    "watch": ("#fdf3e1", "#b7791f"),
    "info": ("#e3f1fb", "#2b6cb0"),
}

REPORT_CSS = """
<style>
  .rpt { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto,
         Helvetica, Arial, sans-serif; color:#1f2933; max-width:1040px;
         margin:0 auto; }
  .rpt h1 { font-size:22px; margin:0 0 4px; color:#102a43; }
  .rpt .sub { color:#647488; font-size:13px; margin:0 0 18px; }
  .rpt .stats { display:flex; gap:10px; flex-wrap:wrap; margin:0 0 22px; }
  .rpt .pill { background:#f0f4f8; border-radius:999px; padding:6px 14px;
               font-size:13px; color:#334e68; }
  .rpt .pill b { color:#102a43; }
  .rpt .card { border:1px solid #e4e7eb; border-radius:12px; padding:16px 18px;
               margin:0 0 14px; box-shadow:0 1px 3px rgba(16,42,67,0.06);
               background:#fff; border-left:4px solid #cbd2d9; }
  .rpt .card.action { border-left-color:#c0392b; }
  .rpt .card.watch { border-left-color:#b7791f; }
  .rpt .card.info { border-left-color:#2b6cb0; }
  .rpt .card h2 { font-size:15px; margin:0 0 6px; color:#102a43; }
  .rpt .sev { display:inline-block; font-size:11px; font-weight:700;
              padding:3px 9px; border-radius:6px; text-transform:uppercase;
              letter-spacing:.03em; margin-right:8px; }
  .rpt .team { display:inline-block; font-size:11px; font-weight:600;
               padding:3px 9px; border-radius:6px; background:#edf0f3;
               color:#52606d; text-transform:uppercase; }
  .rpt .summary { margin:8px 0; font-size:14px; line-height:1.45; }
  .rpt .rationale { font-size:13px; color:#486581; font-style:italic; }
  .rpt .meta { color:#829ab1; font-size:12px; }
  .rpt a { color:#2b6cb0; text-decoration:none; }
  .rpt a:hover { text-decoration:underline; }
  .rpt .empty { color:#829ab1; font-style:italic; padding:8px 0; }
  .rpt .cite { display:flex; gap:9px; align-items:flex-start; background:#f7f9fb;
               border:1px solid #eef1f4; border-radius:8px; padding:8px 10px;
               margin-top:10px; }
  .rpt .cite img.fav { width:16px; height:16px; border-radius:3px; margin-top:2px;
                       flex:0 0 auto; background:#e4e7eb; }
  .rpt .cite .cb { font-size:12px; line-height:1.45; }
  .rpt .cite .cdom { font-weight:600; color:#334e68; }
  .rpt .cite .ctag { font-size:10px; font-weight:700; text-transform:uppercase;
                     color:#fff; background:#5b8def; border-radius:4px;
                     padding:1px 5px; margin-left:6px; }
  .rpt .cite .cmeta { color:#829ab1; }
  .rpt .cite .csnip { color:#52606d; font-style:italic; margin-top:3px; }
  .rpt .yoube { font-size:11px; color:#9aa5b1; margin-top:4px; }
</style>
"""

def _sev_badge(sev: str) -> str:
    bg, fg = _SEVERITY_STYLE.get(sev, ("#edf0f3", "#52606d"))
    return f"<span class='sev' style='background:{bg};color:{fg}'>{sev}</span>"

def _cite(f: Finding) -> str:
    """Render a rich You.com Research citation with domain, date, and snippet."""
    if not f.source_url:
        return ""
    meta = f.published_date[:10] if f.published_date else ""
    snip = f"<div class='csnip'>&ldquo;{f.snippet}&rdquo;</div>" if f.snippet else ""
    return (
        f"<div class='cite'><img class='fav' src='{f.favicon}' alt=''/>"
        f"<div class='cb'>"
        f"<a href='{f.source_url}'><span class='cdom'>{f.domain or 'source'}</span></a>"
        f"<span class='ctag'>research</span>"
        f"<div class='cmeta'>{meta} &middot; {f.title}</div>{snip}</div></div>"
    )

def _render_report(report: ComplianceReport) -> str:
    findings = sorted(
        report.findings,
        key=lambda f: (_SEVERITY_ORDER.get(f.severity, 3), f.team),
    )
    counts = {s: sum(1 for f in findings if f.severity == s) for s in _SEVERITY_ORDER}
    cited = sum(1 for f in findings if f.source_url)

    cards = []
    for f in findings:
        cards.append(
            f"<div class='card {f.severity}'>"
            f"<div>{_sev_badge(f.severity)}<span class='team'>{f.team}</span></div>"
            f"<h2>{f.title or f.topic}</h2>"
            f"<div class='summary'>{f.summary}</div>"
            f"<div class='rationale'>{f.rationale}</div>"
            f"<div class='meta' style='margin-top:6px'>{f.topic}</div>"
            f"{_cite(f)}</div>"
        )

    return f"""
    {REPORT_CSS}
    <div class="rpt">
      <h1>Compliance Monitoring Findings</h1>
      <p class="sub">Citation-precise regulatory changes from trusted domains —
      every finding links to a You.com Research source with snippet provenance.</p>
      <div class="stats">
        <span class="pill"><b>{len(findings)}</b> findings</span>
        <span class="pill"><b>{cited}</b> cited You.com sources</span>
        <span class="pill" style="background:#fdecea;color:#c0392b">
          <b>{counts['action']}</b> action</span>
        <span class="pill" style="background:#fdf3e1;color:#b7791f">
          <b>{counts['watch']}</b> watch</span>
        <span class="pill" style="background:#e3f1fb;color:#2b6cb0">
          <b>{counts['info']}</b> info</span>
      </div>
      {''.join(cards) or "<p class='empty'>No findings in this window.</p>"}
      <p class="yoube">Findings retrieved via the You.com Research API with
      <code>source_control</code> domain allowlists and freshness filters.
      Flyte logs which agent called which query and got which document — full
      prompt &rarr; citation lineage for audit.</p>
    </div>
    """
# {{/docs-fragment report}}

# {{docs-fragment driver}}
def _default_watch_items() -> list[WatchItem]:
    return [
        WatchItem(
            topic="FDA guidance on AI/ML-enabled medical device software",
            trusted_domains=["fda.gov", "federalregister.gov"],
            team="clinical",
        ),
        WatchItem(
            topic="SEC climate-related disclosure rules for public companies",
            trusted_domains=["sec.gov", "federalregister.gov"],
            team="legal",
        ),
        WatchItem(
            topic="OFAC sanctions list additions and updates",
            trusted_domains=["treasury.gov", "ofac.treasury.gov"],
            team="compliance",
        ),
        WatchItem(
            topic="State-level consumer data privacy laws and amendments",
            trusted_domains=["iapp.org", "oag.ca.gov"],
            team="legal",
        ),
        WatchItem(
            topic="FDA drug recalls and safety communications",
            trusted_domains=["fda.gov"],
            team="clinical",
        ),
        WatchItem(
            topic="HIPAA enforcement actions and guidance updates",
            trusted_domains=["hhs.gov"],
            team="compliance",
        ),
    ]

@env.task(report=True)
async def compliance_monitoring(
    watch_items: list[WatchItem] | None = None,
    freshness: str = "month",
) -> ComplianceReport:
    """Fan out across regulatory watch items and aggregate triaged findings."""
    if watch_items is None:
        watch_items = _default_watch_items()

    with flyte.group("monitor-watch-items"):
        results = await asyncio.gather(
            *[monitor_watch_item(item, freshness) for item in watch_items]
        )

    report = ComplianceReport(findings=[f for fs in results for f in fs])

    await flyte.report.replace.aio(_render_report(report), do_flush=True)
    await flyte.report.flush.aio()
    return report
# {{/docs-fragment driver}}

# {{docs-fragment main}}
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(compliance_monitoring)
    print(run.url)
    run.wait()
# {{/docs-fragment main}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/compliance_monitoring_agent/main.py*

## Orchestration

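The `compliance_monitoring` driver task runs one `monitor_watch_item` task per watch item concurrently with `asyncio.gather`, inside a `flyte.group` named `monitor-watch-items`. It then flattens the per-item results into a single `ComplianceReport` and renders a Flyte report sorted by severity and then by team. When no `watch_items` are passed, it monitors the six defaults returned by `_default_watch_items()`.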

```
# {{docs-fragment driver}}
def _default_watch_items() -> list[WatchItem]:
    return [
        WatchItem(
            topic="FDA guidance on AI/ML-enabled medical device software",
            trusted_domains=["fda.gov", "federalregister.gov"],
            team="clinical",
        ),
        WatchItem(
            topic="SEC climate-related disclosure rules for public companies",
            trusted_domains=["sec.gov", "federalregister.gov"],
            team="legal",
        ),
        WatchItem(
            topic="OFAC sanctions list additions and updates",
            trusted_domains=["treasury.gov", "ofac.treasury.gov"],
            team="compliance",
        ),
        WatchItem(
            topic="State-level consumer data privacy laws and amendments",
            trusted_domains=["iapp.org", "oag.ca.gov"],
            team="legal",
        ),
        WatchItem(
            topic="FDA drug recalls and safety communications",
            trusted_domains=["fda.gov"],
            team="clinical",
        ),
        WatchItem(
            topic="HIPAA enforcement actions and guidance updates",
            trusted_domains=["hhs.gov"],
            team="compliance",
        ),
    ]

@env.task(report=True)
async def compliance_monitoring(
    watch_items: list[WatchItem] | None = None,
    freshness: str = "month",
) -> ComplianceReport:
    """Fan out across regulatory watch items and aggregate triaged findings."""
    if watch_items is None:
        watch_items = _default_watch_items()

    with flyte.group("monitor-watch-items"):
        results = await asyncio.gather(
            *[monitor_watch_item(item, freshness) for item in watch_items]
        )

    report = ComplianceReport(findings=[f for fs in results for f in fs])

    await flyte.report.replace.aio(_render_report(report), do_flush=True)
    await flyte.report.flush.aio()
    return report
# {{/docs-fragment driver}}

# {{docs-fragment main}}
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(compliance_monitoring)
    print(run.url)
    run.wait()
# {{/docs-fragment main}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/compliance_monitoring_agent/main.py*
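
With the Flyte pieces removed, the driver's fan-out, flatten, and sort steps reduce to plain `asyncio`. The sketch below mirrors that shape under stated assumptions: `monitor` is a hypothetical stand-in for `monitor_watch_item` that returns canned findings instead of calling You.com or Claude, `Finding` is trimmed to three fields, and there is no `flyte.group`, retry policy, or report rendering. In the tutorial, each call instead invokes the `monitor_watch_item` task, which carries its own `retries=3`.

```python
import asyncio
from dataclasses import dataclass

# Same ranking the tutorial's report uses: most urgent first.
SEVERITY_ORDER = {"action": 0, "watch": 1, "info": 2}


@dataclass
class Finding:
    # Trimmed-down version of the tutorial's Finding dataclass.
    topic: str
    team: str
    severity: str


async def monitor(topic: str, team: str) -> list[Finding]:
    """Hypothetical stand-in for monitor_watch_item that returns canned findings."""
    await asyncio.sleep(0)  # yield to the event loop, as real network I/O would
    severity = "action" if "sanctions" in topic.lower() else "info"
    return [Finding(topic=topic, team=team, severity=severity)]


async def fan_out(items: list[tuple[str, str]]) -> list[Finding]:
    # Run every watch item concurrently; gather returns results in input order.
    results = await asyncio.gather(*[monitor(topic, team) for topic, team in items])
    # Flatten list[list[Finding]] into a single list, as the driver does.
    return [f for fs in results for f in fs]


def sort_for_report(findings: list[Finding]) -> list[Finding]:
    # Unknown severities rank last (3); ties are broken alphabetically by team.
    return sorted(findings, key=lambda f: (SEVERITY_ORDER.get(f.severity, 3), f.team))


if __name__ == "__main__":
    watch = [
        ("HIPAA guidance updates", "compliance"),
        ("OFAC sanctions updates", "compliance"),
        ("State privacy laws", "legal"),
    ]
    for f in sort_for_report(asyncio.run(fan_out(watch))):
        print(f.severity, f.team, f.topic)
```

Running it prints the OFAC item first (`action`), then the two `info` items ordered by team. Because `asyncio.gather` is called without `return_exceptions=True`, a single failing `monitor` call makes `fan_out` raise in this sketch rather than returning partial results.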

## Run the agent

### Create secrets

Get a You.com API key from the [You.com platform](https://you.com/platform) (see the [quickstart guide](https://you.com/docs/quickstart)). Get an Anthropic API key from the [Anthropic console](https://console.anthropic.com/).

Register both keys as Flyte secrets. The secret key names must match those declared in the `TaskEnvironment`:

```
flyte create secret youdotcom-api-key <YOUR_YOU_API_KEY>
flyte create secret internal-anthropic-api-key <YOUR_ANTHROPIC_API_KEY>
```

See [Secrets](https://www.union.ai/docs/v2/union/user-guide/task-configuration/secrets/page.md) for scoping and file-based secrets.

### Run locally or remotely

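From the root of a clone of the [`unionai-examples`](https://github.com/unionai/unionai-examples) repository, change into the [example directory](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/compliance_monitoring_agent) and run the script. Its `__main__` block calls `flyte.init_from_config()`, launches `compliance_monitoring`, prints the run URL, and waits for the run to finish: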

```
cd v2/tutorials/compliance_monitoring_agent
uv run --script main.py
```

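To test locally without creating Flyte secrets, export the keys under the environment variable names the tasks read (`YOU_API_KEY` and `ANTHROPIC_API_KEY`) before running the script: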

```
export YOU_API_KEY=<YOUR_YOU_API_KEY>
export ANTHROPIC_API_KEY=<YOUR_ANTHROPIC_API_KEY>

uv run --script main.py
```

When the run completes, open the Flyte report to review the findings, sorted by severity (`action`, then `watch`, then `info`) and then by team, each with a verifiable You.com Research citation.

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/agents/field-data-enrichment-agent ===

# Field data enrichment agent

> [!NOTE]
> Code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/field_data_enrichment_agent).

This example demonstrates how to build a field-data enrichment agent for autonomous systems on Flyte. The agent enriches geo-tagged operational events from autonomous vehicles, aircraft, satellites, or field sensors with **real-world public context**, such as road closures, weather events, airspace changes, or local incidents tied to a geofence.

Operational data stays in your environment; only public-web grounding queries go to the [You.com Search API](https://you.com/docs/search/overview), which returns unified web and news results with `freshness` and `country` targeting. [Claude](https://docs.anthropic.com/) via [LiteLLM](https://docs.litellm.ai/) then summarizes the relevant context for each geo-tagged event.
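
The boundary between operational data and public-web queries comes down to what goes into each request. The minimal sketch below assumes a hypothetical `build_search_params` helper, illustrative query wording, and two-letter country codes; the tutorial's actual query is built in `enrich_event`. Only location-level text plus the `count`, `freshness`, and `country` parameters that `you_search` sends leave your environment, and `flatten_results` merges the response's `news` and `web` sections into one list, news first, the same order `you_search` uses.

```python
from urllib.parse import urlparse


def build_search_params(
    location: str, event_type: str, country: str, freshness: str = "day", count: int = 8
) -> dict:
    """Hypothetical query builder: only public, location-level text leaves your cloud."""
    # Internal fields such as event_id are deliberately left out of the request.
    return {
        "query": f"{event_type} near {location}",  # illustrative wording only
        "count": count,
        "freshness": freshness,
        "country": country,
    }


def flatten_results(data: dict) -> list[dict]:
    """Merge the response's news and web sections into one list, news first."""
    hits: list[dict] = []
    results = data.get("results", {})
    for section in ("news", "web"):
        for item in results.get(section) or []:
            url = item.get("url", "")
            hits.append({
                "section": section,
                "url": url,
                "domain": urlparse(url).netloc.removeprefix("www."),
            })
    return hits


params = build_search_params("Phoenix, AZ", "road closure", "US")
print(params["query"])  # road closure near Phoenix, AZ
```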

Flyte provides:

- **Fan-out parallelism** across geo-tagged events
- **`cache="auto"`** so repeated geofence checks with identical inputs reuse prior results instead of calling the APIs again
- **`@flyte.trace`** on every external call for lineage
- **Flyte reports** with operational severity and per-incident citations

![Field data enrichment agent report](https://www.union.ai/docs/v2/union/_static/images/tutorials/field_data_enrichment_agent/field-data-enrichment-data.png)
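
Conceptually, the input-keyed caching enabled by `cache="auto"` behaves like a memo table: an invocation whose function identity, code version, and inputs match an earlier one returns the stored result instead of running again. The toy `MemoCache` below illustrates only that idea; it is an analogy, not how Flyte computes or stores cache keys, and `check_geofence` is a hypothetical stand-in for an expensive search-plus-LLM call such as `enrich_event`.

```python
import hashlib
import json
from typing import Any, Callable


def cache_key(fn_name: str, version: str, inputs: dict[str, Any]) -> str:
    """Deterministic key from the function identity, a code version, and its inputs."""
    payload = json.dumps(
        {"fn": fn_name, "version": version, "inputs": inputs}, sort_keys=True
    )
    return hashlib.sha256(payload.encode()).hexdigest()


class MemoCache:
    """Toy input-keyed cache; an analogy only, not Flyte's implementation."""

    def __init__(self) -> None:
        self._store: dict[str, Any] = {}
        self.misses = 0

    def call(self, fn: Callable[..., Any], version: str, **inputs: Any) -> Any:
        key = cache_key(fn.__name__, version, inputs)
        if key not in self._store:
            self.misses += 1  # miss: actually execute the function
            self._store[key] = fn(**inputs)
        return self._store[key]  # stored result, whether just computed or reused


def check_geofence(location: str, freshness: str) -> str:
    # Hypothetical stand-in for an expensive search + LLM call.
    return f"context for {location} ({freshness})"


cache = MemoCache()
cache.call(check_geofence, "v1", location="Reno, NV", freshness="day")
cache.call(check_geofence, "v1", location="Reno, NV", freshness="day")  # hit
cache.call(check_geofence, "v1", location="Boise, ID", freshness="day")  # miss
print(cache.misses)  # 2
```

Changing the version string, as an edit to the task's code would in the analogy, produces a new key and forces a fresh execution.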

## Setting up the environment

The agent runs in a `TaskEnvironment` with secrets for the You.com and Anthropic API keys, automatic caching, and a container image built from the `uv` script dependencies.

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#     "flyte>=2.4.0",
#     "httpx>=0.27.0",
#     "litellm>=1.72.0",
# ]
# main = "field_data_enrichment"
# params = ""
# ///
"""Autonomous systems & field-data enrichment agent.

Enriches geo-tagged operational events with real-world public context (road
closures, weather, incidents) using the You.com Search API with country +
freshness targeting, then uses Claude to summarize the relevant context. Only
public-web grounding queries leave the customer's cloud, never operational data.
"""

# {{docs-fragment env}}
import asyncio
import json
import os
from dataclasses import dataclass, field

import flyte

MODEL = "anthropic/claude-haiku-4-5"

env = flyte.TaskEnvironment(
    name="field-data-enrichment",
    secrets=[
        flyte.Secret(key="youdotcom-api-key", as_env_var="YOU_API_KEY"),
        flyte.Secret(key="internal-anthropic-api-key", as_env_var="ANTHROPIC_API_KEY"),
    ],
    image=flyte.Image.from_uv_script(__file__, name="field-data-enrichment", pre=True),
    resources=flyte.Resources(cpu="1", memory="1Gi"),
    cache="auto",
)
# {{/docs-fragment env}}

# {{docs-fragment data_types}}
@dataclass
class GeoEvent:
    event_id: str
    location: str
    country: str
    event_type: str

@dataclass
class Incident:
    description: str
    source_url: str
    published: str
    domain: str = ""
    author: str = ""
    favicon: str = ""
    snippet: str = ""
    section: str = "web"

@dataclass
class EnrichedEvent:
    event_id: str
    location: str
    context_summary: str
    severity: str
    incidents: list[Incident] = field(default_factory=list)

@dataclass
class EnrichmentReport:
    events: list[EnrichedEvent] = field(default_factory=list)
# {{/docs-fragment data_types}}

# {{docs-fragment you_search}}
YOU_SEARCH_URL = "https://ydc-index.io/v1/search"

@dataclass
class SearchHit:
    title: str
    url: str
    domain: str
    snippet: str
    published: str
    author: str
    favicon: str
    section: str

def _domain(url: str) -> str:
    from urllib.parse import urlparse

    try:
        return urlparse(url).netloc.replace("www.", "")
    except Exception:
        return ""

def _favicon(item: dict, url: str) -> str:
    return item.get("favicon_url") or (
        f"https://ydc-index.io/favicon?domain={_domain(url)}&size=128"
    )

async def _you_get(url: str, params: dict, timeout: float = 60.0) -> dict:
    """GET with exponential backoff + jitter on 429 rate limits."""
    import asyncio
    import random

    import httpx

    headers = {"X-API-Key": os.environ["YOU_API_KEY"]}
    async with httpx.AsyncClient(timeout=timeout) as client:
        for attempt in range(7):
            resp = await client.get(url, headers=headers, params=params)
            if resp.status_code == 429 and attempt < 6:
                wait = float(resp.headers.get("retry-after") or 0) or min(2**attempt, 30)
                await asyncio.sleep(wait + random.uniform(0, 2))
                continue
            resp.raise_for_status()
            return resp.json()
    resp.raise_for_status()
    return resp.json()

@flyte.trace
async def you_search(
    query: str, country: str, freshness: str = "day", count: int = 8
) -> list[SearchHit]:
    """Search the public web + news for context near a geofenced location."""
    params = {
        "query": query,
        "count": count,
        "freshness": freshness,
        "country": country,
    }
    data = await _you_get(YOU_SEARCH_URL, params)

    results = data.get("results", {})
    hits: list[SearchHit] = []
    for section in ("news", "web"):
        for item in results.get(section, []) or []:
            snippets = item.get("snippets") or []
            url = item.get("url", "")
            hits.append(
                SearchHit(
                    title=item.get("title", ""),
                    url=url,
                    domain=_domain(url),
                    snippet=(snippets[0] if snippets else item.get("description", "")),
                    published=item.get("page_age", "") or "",
                    author=", ".join(item.get("authors") or []),
                    favicon=_favicon(item, url),
                    section=section,
                )
            )
    return hits
# {{/docs-fragment you_search}}

# {{docs-fragment llm}}
@flyte.trace
async def llm_json(system: str, user: str) -> dict:
    from litellm import acompletion

    resp = await acompletion(
        model=MODEL,
        messages=[
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ],
        temperature=0.0,
        max_tokens=1536,
    )
    parsed = _parse_json(resp.choices[0].message.content)
    return parsed if isinstance(parsed, dict) else {}

def _parse_json(text: str) -> dict | list:
    text = text.strip()
    if text.startswith("```"):
        text = text.split("```", 2)[1]
        if text.lstrip().startswith("json"):
            text = text.lstrip()[4:]
    start = min((i for i in (text.find("{"), text.find("[")) if i != -1), default=0)
    end = max(text.rfind("}"), text.rfind("]")) + 1
    return json.loads(text[start:end])
# {{/docs-fragment llm}}

ENRICH_SYSTEM = """You are an operational-context analyst for autonomous and \
field systems. Given fresh local search results near a geofenced location, \
summarize the real-world context relevant to operations, extract discrete \
incidents (road closures, weather events, regulatory/airspace changes, local \
incidents), and assign an operational severity of 'none', 'low', 'medium', or \
'high'. Each incident must reference the supporting search result by its index. \
Respond ONLY with JSON:
{"context_summary": str, "severity": str, "incidents": [{"description": str, \
"source_index": int (the [n] of the supporting search result)}]}"""

# {{docs-fragment enrich_event}}
@env.task(retries=3)
async def enrich_event(event: GeoEvent, freshness: str) -> EnrichedEvent:
    """Ground one geo-tagged event in fresh public context."""
    query = f"{event.location} {event.event_type.replace('_', ' ')} road closure weather incident"
    hits = await you_search(query, country=event.country, freshness=freshness)

    evidence = "\n\n".join(
        f"[{i + 1}] {h.title} ({h.published}) — {h.domain}\n{h.url}\n{h.snippet}"
        for i, h in enumerate(hits)
    )
    user = (
        f"Location: {event.location}\n"
        f"Event type: {event.event_type}\n\n"
        f"Search results:\n{evidence or 'No results.'}"
    )
    parsed = await llm_json(ENRICH_SYSTEM, user)

    def _incident(it: dict) -> Incident:
        idx = int(it.get("source_index", 0) or 0)
        src = hits[idx - 1] if 1 <= idx <= len(hits) else None
        return Incident(
            description=str(it.get("description", "")),
            source_url=src.url if src else "",
            published=src.published if src else "",
            domain=src.domain if src else "",
            author=src.author if src else "",
            favicon=src.favicon if src else "",
            snippet=src.snippet if src else "",
            section=src.section if src else "web",
        )

    incidents = [_incident(it) for it in (parsed.get("incidents", []) or [])]
    return EnrichedEvent(
        event_id=event.event_id,
        location=event.location,
        context_summary=str(parsed.get("context_summary", "")),
        severity=str(parsed.get("severity", "none")),
        incidents=incidents,
    )
# {{/docs-fragment enrich_event}}

# {{docs-fragment report}}
_SEVERITY_ORDER = {"high": 0, "medium": 1, "low": 2, "none": 3}
_SEVERITY_STYLE = {
    "high": ("#fdecea", "#c0392b"),
    "medium": ("#fdf3e1", "#b7791f"),
    "low": ("#e3f1fb", "#2b6cb0"),
    "none": ("#eef1f4", "#627d98"),
}

REPORT_CSS = """
<style>
  .rpt { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto,
         Helvetica, Arial, sans-serif; color:#1f2933; max-width:1040px;
         margin:0 auto; }
  .rpt h1 { font-size:22px; margin:0 0 4px; color:#102a43; }
  .rpt .sub { color:#647488; font-size:13px; margin:0 0 18px; }
  .rpt .stats { display:flex; gap:10px; flex-wrap:wrap; margin:0 0 22px; }
  .rpt .pill { background:#f0f4f8; border-radius:999px; padding:6px 14px;
               font-size:13px; color:#334e68; }
  .rpt .pill b { color:#102a43; }
  .rpt .card { border:1px solid #e4e7eb; border-radius:12px; padding:16px 18px;
               margin:0 0 14px; box-shadow:0 1px 3px rgba(16,42,67,0.06);
               background:#fff; border-left:4px solid #cbd2d9; }
  .rpt .card.high { border-left-color:#c0392b; }
  .rpt .card.medium { border-left-color:#b7791f; }
  .rpt .card.low { border-left-color:#2b6cb0; }
  .rpt .card h2 { font-size:15px; margin:0 0 6px; color:#102a43; }
  .rpt .sev { display:inline-block; font-size:11px; font-weight:700;
              padding:3px 9px; border-radius:6px; text-transform:uppercase;
              letter-spacing:.03em; margin-right:8px; }
  .rpt .loc { font-size:13px; color:#52606d; }
  .rpt .summary { margin:8px 0; font-size:14px; line-height:1.45; }
  .rpt .inc { font-size:13px; color:#334e68; padding:6px 0; }
  .rpt .meta { color:#829ab1; font-size:12px; }
  .rpt a { color:#2b6cb0; text-decoration:none; }
  .rpt a:hover { text-decoration:underline; }
  .rpt .empty { color:#829ab1; font-style:italic; padding:8px 0; }
  .rpt .cite { display:flex; gap:9px; align-items:flex-start; background:#f7f9fb;
               border:1px solid #eef1f4; border-radius:8px; padding:7px 10px;
               margin:5px 0 2px 14px; }
  .rpt .cite img.fav { width:15px; height:15px; border-radius:3px; margin-top:2px;
                       flex:0 0 auto; background:#e4e7eb; }
  .rpt .cite .cb { font-size:12px; line-height:1.4; }
  .rpt .cite .cdom { font-weight:600; color:#334e68; }
  .rpt .cite .ctag { font-size:10px; font-weight:700; text-transform:uppercase;
                     color:#fff; background:#bcccdc; border-radius:4px;
                     padding:1px 5px; margin-left:6px; }
  .rpt .cite .ctag.news { background:#e8833a; }
  .rpt .cite .cmeta { color:#829ab1; }
  .rpt .cite .csnip { color:#52606d; font-style:italic; margin-top:2px; }
  .rpt .yoube { font-size:11px; color:#9aa5b1; margin-top:4px; }
</style>
"""

def _sev_badge(sev: str) -> str:
    bg, fg = _SEVERITY_STYLE.get(sev, ("#eef1f4", "#627d98"))
    return f"<span class='sev' style='background:{bg};color:{fg}'>{sev}</span>"

def _cite(it: Incident) -> str:
    """Render a rich You.com citation for an incident's supporting source."""
    if not it.source_url:
        return ""
    tag = (
        "<span class='ctag news'>news</span>"
        if it.section == "news"
        else "<span class='ctag'>web</span>"
    )
    meta_bits = []
    if it.published:
        meta_bits.append(it.published[:10])
    if it.author:
        meta_bits.append(f"by {it.author}")
    meta = " &middot; ".join(meta_bits)
    snip = f"<div class='csnip'>&ldquo;{it.snippet}&rdquo;</div>" if it.snippet else ""
    return (
        f"<div class='cite'><img class='fav' src='{it.favicon}' alt=''/>"
        f"<div class='cb'>"
        f"<a href='{it.source_url}'><span class='cdom'>{it.domain or 'source'}</span></a>{tag}"
        f"<div class='cmeta'>{meta}</div>{snip}</div></div>"
    )

def _render_report(report: EnrichmentReport) -> str:
    events = sorted(report.events, key=lambda e: _SEVERITY_ORDER.get(e.severity, 4))
    flagged = sum(1 for e in events if e.severity in ("high", "medium"))
    # Count only incidents whose source_index resolved to a search result.
    total_sources = sum(1 for e in events for it in e.incidents if it.source_url)

    cards = []
    for e in events:
        incidents = "".join(
            f"<div class='inc'>&bull; {it.description}{_cite(it)}</div>"
            for it in e.incidents
        )
        cards.append(
            f"<div class='card {e.severity}'>"
            f"<div>{_sev_badge(e.severity)}"
            f"<span class='loc'><b>{e.event_id}</b> &middot; {e.location}</span></div>"
            f"<div class='summary'>{e.context_summary or 'No relevant public context found.'}</div>"
            f"{incidents}</div>"
        )

    return f"""
    {REPORT_CSS}
    <div class="rpt">
      <h1>Field-Data Enrichment</h1>
      <p class="sub">Geo-tagged events grounded in fresh public context — each
      incident cites a timestamped You.com Search result.</p>
      <div class="stats">
        <span class="pill"><b>{len(events)}</b> events</span>
        <span class="pill" style="background:#fdecea;color:#c0392b">
          <b>{flagged}</b> flagged (high/medium)</span>
        <span class="pill"><b>{total_sources}</b> cited You.com sources</span>
      </div>
      {''.join(cards) or "<p class='empty'>No events processed.</p>"}
      <p class="yoube">Public context retrieved via the You.com Search API with
      country + freshness targeting. Operational data never leaves the BYOC
      boundary — only public-web queries go out.</p>
    </div>
    """
# {{/docs-fragment report}}

# {{docs-fragment driver}}
DEFAULT_EVENTS = [
    GeoEvent("evt-1", "Mountain View, CA", "US", "road_closure_check"),
    GeoEvent("evt-2", "Tokyo, Japan", "JP", "weather"),
    GeoEvent("evt-3", "Austin, TX", "US", "road_closure_check"),
    GeoEvent("evt-4", "Phoenix, AZ", "US", "weather"),
    GeoEvent("evt-5", "London, UK", "GB", "incident"),
    GeoEvent("evt-6", "San Francisco, CA", "US", "incident"),
    GeoEvent("evt-7", "Seattle, WA", "US", "weather"),
    GeoEvent("evt-8", "Miami, FL", "US", "weather"),
    GeoEvent("evt-9", "Denver, CO", "US", "road_closure_check"),
    GeoEvent("evt-10", "Berlin, Germany", "DE", "incident"),
]

@env.task(report=True)
async def field_data_enrichment(
    events: list[GeoEvent] = DEFAULT_EVENTS,
    freshness: str = "day",
) -> EnrichmentReport:
    """Fan out across geo-tagged events and enrich each with public context."""
    with flyte.group("enrich-events"):
        enriched = await asyncio.gather(
            *[enrich_event(e, freshness) for e in events]
        )

    report = EnrichmentReport(events=list(enriched))
    await flyte.report.replace.aio(_render_report(report), do_flush=True)
    await flyte.report.flush.aio()
    return report
# {{/docs-fragment driver}}

# {{docs-fragment main}}
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(field_data_enrichment)
    print(run.url)
    run.wait()
# {{/docs-fragment main}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/field_data_enrichment_agent/main.py*
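`_parse_json` in the listing above slices from the first `{` or `[` to the last `}` or `]`. A reply such as `{"severity": "high"} (source [2])` therefore fails with an "Extra data" error, because the slice runs into the trailing `]`, and the exception propagates out of `enrich_event`. One alternative is to let `json.JSONDecoder.raw_decode` stop at the end of the first complete value. The sketch below assumes the model's answer is a single JSON object; `extract_json_object` is a hypothetical helper, not part of the tutorial code.

```python
import json


def extract_json_object(text: str) -> dict:
    """Return the first complete JSON object found in an LLM reply, or {}.

    raw_decode parses one value starting at `idx` and reports where it ended,
    so anything after the object (prose, citations, a closing code fence) is
    ignored instead of being fed to the parser.
    """
    decoder = json.JSONDecoder()
    idx = text.find("{")
    while idx != -1:
        try:
            value, _end = decoder.raw_decode(text, idx)
            return value  # parsing started at "{", so this is always a dict
        except json.JSONDecodeError:
            # Not valid JSON at this brace; try the next "{".
            idx = text.find("{", idx + 1)
    return {}


reply = 'Result: {"severity": "high", "incidents": []} (source [2])'
print(extract_json_object(reply))
```

In `llm_json`, calling `extract_json_object(resp.choices[0].message.content)` in place of `_parse_json(...)` would also make the `isinstance(parsed, dict)` check unnecessary, since the helper only ever returns a dict.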

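The script's Python dependencies are declared at the top of the file as [PEP 723](https://peps.python.org/pep-0723/) inline script metadata, which `uv` understands and which `flyte.Image.from_uv_script` uses to build the task image: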

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#     "flyte>=2.4.0",
#     "httpx>=0.27.0",
#     "litellm>=1.72.0",
# ]
# ///
```

## Data types

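Each `GeoEvent` carries an event ID, a location, a two-letter ISO 3166-1 country code used to geo-target the search, and an event type. An `EnrichedEvent` adds a context summary, an operational severity (`none`, `low`, `medium`, or `high`), and a list of `Incident` records, each citing the search result that supports it.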

```
from dataclasses import dataclass, field

@dataclass
class GeoEvent:
    event_id: str
    location: str
    country: str
    event_type: str

@dataclass
class Incident:
    description: str
    source_url: str
    published: str
    domain: str = ""
    author: str = ""
    favicon: str = ""
    snippet: str = ""
    section: str = "web"

@dataclass
class EnrichedEvent:
    event_id: str
    location: str
    context_summary: str
    severity: str
    incidents: list[Incident] = field(default_factory=list)

@dataclass
class EnrichmentReport:
    events: list[EnrichedEvent] = field(default_factory=list)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/field_data_enrichment_agent/main.py*
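The `severity` field holds whatever string the model returns. `_render_report` both sorts on it (via `_SEVERITY_ORDER.get(e.severity, 4)`) and uses it as a CSS class, so a reply of `"High"` would sort after `"none"` and would not match the `.card.high` rule. The following is a minimal sketch of normalizing the value before building an `EnrichedEvent`; the helper names are hypothetical. Mapping unrecognized labels to `"none"` mirrors the code's default for a missing severity, though a stricter pipeline might reject such a reply instead.

```python
SEVERITY_ORDER = {"high": 0, "medium": 1, "low": 2, "none": 3}


def normalize_severity(raw: object) -> str:
    """Map a model-supplied severity onto the four levels the prompt allows."""
    sev = str(raw or "").strip().lower()
    # Unrecognized labels fall back to "none", like a missing severity does.
    return sev if sev in SEVERITY_ORDER else "none"


def sort_by_severity(events: list[tuple[str, str]]) -> list[tuple[str, str]]:
    """Sort (event_id, severity) pairs most severe first.

    sorted() is stable, so events with equal severity keep their input order.
    """
    return sorted(events, key=lambda ev: SEVERITY_ORDER[normalize_severity(ev[1])])


events = [("evt-1", "low"), ("evt-2", " HIGH "), ("evt-3", "medium"), ("evt-4", "critical")]
print([event_id for event_id, _ in sort_by_severity(events)])
```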

## Search with the You.com Search API

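The `you_search` helper calls the [You.com Search API](https://you.com/docs/search/overview) with `freshness` and `country` parameters to retrieve location-relevant web and news results, then flattens both result sections into `SearchHit` records. The underlying `_you_get` call retries HTTP 429 (rate-limit) responses with exponential backoff and jitter, honoring the `Retry-After` header when the server sends one. See the [Search API reference](https://you.com/docs/api-reference/search/v1-search) for supported country codes and freshness values.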

```
# {{docs-fragment you_search}}
YOU_SEARCH_URL = "https://ydc-index.io/v1/search"

@dataclass
class SearchHit:
    title: str
    url: str
    domain: str
    snippet: str
    published: str
    author: str
    favicon: str
    section: str

def _domain(url: str) -> str:
    from urllib.parse import urlparse

    try:
        return urlparse(url).netloc.replace("www.", "")
    except Exception:
        return ""

def _favicon(item: dict, url: str) -> str:
    return item.get("favicon_url") or (
        f"https://ydc-index.io/favicon?domain={_domain(url)}&size=128"
    )

async def _you_get(url: str, params: dict, timeout: float = 60.0) -> dict:
    """GET with exponential backoff + jitter on 429 rate limits."""
    import asyncio
    import random

    import httpx

    headers = {"X-API-Key": os.environ["YOU_API_KEY"]}
    async with httpx.AsyncClient(timeout=timeout) as client:
        for attempt in range(7):
            resp = await client.get(url, headers=headers, params=params)
            if resp.status_code == 429 and attempt < 6:
                wait = float(resp.headers.get("retry-after") or 0) or min(2**attempt, 30)
                await asyncio.sleep(wait + random.uniform(0, 2))
                continue
            resp.raise_for_status()
            return resp.json()
    resp.raise_for_status()
    return resp.json()

@flyte.trace
async def you_search(
    query: str, country: str, freshness: str = "day", count: int = 8
) -> list[SearchHit]:
    """Search the public web + news for context near a geofenced location."""
    params = {
        "query": query,
        "count": count,
        "freshness": freshness,
        "country": country,
    }
    data = await _you_get(YOU_SEARCH_URL, params)

    results = data.get("results", {})
    hits: list[SearchHit] = []
    for section in ("news", "web"):
        for item in results.get(section, []) or []:
            snippets = item.get("snippets") or []
            url = item.get("url", "")
            hits.append(
                SearchHit(
                    title=item.get("title", ""),
                    url=url,
                    domain=_domain(url),
                    snippet=(snippets[0] if snippets else item.get("description", "")),
                    published=item.get("page_age", "") or "",
                    author=", ".join(item.get("authors") or []),
                    favicon=_favicon(item, url),
                    section=section,
                )
            )
    return hits
# {{/docs-fragment you_search}}

# {{docs-fragment llm}}
@flyte.trace
async def llm_json(system: str, user: str) -> dict:
    from litellm import acompletion

    resp = await acompletion(
        model=MODEL,
        messages=[
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ],
        temperature=0.0,
        max_tokens=1536,
    )
    parsed = _parse_json(resp.choices[0].message.content)
    return parsed if isinstance(parsed, dict) else {}

def _parse_json(text: str) -> dict | list:
    text = text.strip()
    if text.startswith("```"):
        text = text.split("```", 2)[1]
        if text.lstrip().startswith("json"):
            text = text.lstrip()[4:]
    start = min((i for i in (text.find("{"), text.find("[")) if i != -1), default=0)
    end = max(text.rfind("}"), text.rfind("]")) + 1
    return json.loads(text[start:end])
# {{/docs-fragment llm}}

ENRICH_SYSTEM = """You are an operational-context analyst for autonomous and \
field systems. Given fresh local search results near a geofenced location, \
summarize the real-world context relevant to operations, extract discrete \
incidents (road closures, weather events, regulatory/airspace changes, local \
incidents), and assign an operational severity of 'none', 'low', 'medium', or \
'high'. Each incident must reference the supporting search result by its index. \
Respond ONLY with JSON:
{"context_summary": str, "severity": str, "incidents": [{"description": str, \
"source_index": int (the [n] of the supporting search result)}]}"""

# {{docs-fragment enrich_event}}
@env.task(retries=3)
async def enrich_event(event: GeoEvent, freshness: str) -> EnrichedEvent:
    """Ground one geo-tagged event in fresh public context."""
    query = f"{event.location} {event.event_type.replace('_', ' ')} road closure weather incident"
    hits = await you_search(query, country=event.country, freshness=freshness)

    evidence = "\n\n".join(
        f"[{i + 1}] {h.title} ({h.published}) — {h.domain}\n{h.url}\n{h.snippet}"
        for i, h in enumerate(hits)
    )
    user = (
        f"Location: {event.location}\n"
        f"Event type: {event.event_type}\n\n"
        f"Search results:\n{evidence or 'No results.'}"
    )
    parsed = await llm_json(ENRICH_SYSTEM, user)

    def _incident(it: dict) -> Incident:
        idx = int(it.get("source_index", 0) or 0)
        src = hits[idx - 1] if 1 <= idx <= len(hits) else None
        return Incident(
            description=str(it.get("description", "")),
            source_url=src.url if src else "",
            published=src.published if src else "",
            domain=src.domain if src else "",
            author=src.author if src else "",
            favicon=src.favicon if src else "",
            snippet=src.snippet if src else "",
            section=src.section if src else "web",
        )

    incidents = [_incident(it) for it in (parsed.get("incidents", []) or [])]
    return EnrichedEvent(
        event_id=event.event_id,
        location=event.location,
        context_summary=str(parsed.get("context_summary", "")),
        severity=str(parsed.get("severity", "none")),
        incidents=incidents,
    )
# {{/docs-fragment enrich_event}}

# {{docs-fragment report}}
_SEVERITY_ORDER = {"high": 0, "medium": 1, "low": 2, "none": 3}
_SEVERITY_STYLE = {
    "high": ("#fdecea", "#c0392b"),
    "medium": ("#fdf3e1", "#b7791f"),
    "low": ("#e3f1fb", "#2b6cb0"),
    "none": ("#eef1f4", "#627d98"),
}

REPORT_CSS = """
<style>
  .rpt { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto,
         Helvetica, Arial, sans-serif; color:#1f2933; max-width:1040px;
         margin:0 auto; }
  .rpt h1 { font-size:22px; margin:0 0 4px; color:#102a43; }
  .rpt .sub { color:#647488; font-size:13px; margin:0 0 18px; }
  .rpt .stats { display:flex; gap:10px; flex-wrap:wrap; margin:0 0 22px; }
  .rpt .pill { background:#f0f4f8; border-radius:999px; padding:6px 14px;
               font-size:13px; color:#334e68; }
  .rpt .pill b { color:#102a43; }
  .rpt .card { border:1px solid #e4e7eb; border-radius:12px; padding:16px 18px;
               margin:0 0 14px; box-shadow:0 1px 3px rgba(16,42,67,0.06);
               background:#fff; border-left:4px solid #cbd2d9; }
  .rpt .card.high { border-left-color:#c0392b; }
  .rpt .card.medium { border-left-color:#b7791f; }
  .rpt .card.low { border-left-color:#2b6cb0; }
  .rpt .card h2 { font-size:15px; margin:0 0 6px; color:#102a43; }
  .rpt .sev { display:inline-block; font-size:11px; font-weight:700;
              padding:3px 9px; border-radius:6px; text-transform:uppercase;
              letter-spacing:.03em; margin-right:8px; }
  .rpt .loc { font-size:13px; color:#52606d; }
  .rpt .summary { margin:8px 0; font-size:14px; line-height:1.45; }
  .rpt .inc { font-size:13px; color:#334e68; padding:6px 0; }
  .rpt .meta { color:#829ab1; font-size:12px; }
  .rpt a { color:#2b6cb0; text-decoration:none; }
  .rpt a:hover { text-decoration:underline; }
  .rpt .empty { color:#829ab1; font-style:italic; padding:8px 0; }
  .rpt .cite { display:flex; gap:9px; align-items:flex-start; background:#f7f9fb;
               border:1px solid #eef1f4; border-radius:8px; padding:7px 10px;
               margin:5px 0 2px 14px; }
  .rpt .cite img.fav { width:15px; height:15px; border-radius:3px; margin-top:2px;
                       flex:0 0 auto; background:#e4e7eb; }
  .rpt .cite .cb { font-size:12px; line-height:1.4; }
  .rpt .cite .cdom { font-weight:600; color:#334e68; }
  .rpt .cite .ctag { font-size:10px; font-weight:700; text-transform:uppercase;
                     color:#fff; background:#bcccdc; border-radius:4px;
                     padding:1px 5px; margin-left:6px; }
  .rpt .cite .ctag.news { background:#e8833a; }
  .rpt .cite .cmeta { color:#829ab1; }
  .rpt .cite .csnip { color:#52606d; font-style:italic; margin-top:2px; }
  .rpt .yoube { font-size:11px; color:#9aa5b1; margin-top:4px; }
</style>
"""

def _sev_badge(sev: str) -> str:
    bg, fg = _SEVERITY_STYLE.get(sev, ("#eef1f4", "#627d98"))
    return f"<span class='sev' style='background:{bg};color:{fg}'>{sev}</span>"

def _cite(it: Incident) -> str:
    """Render a rich You.com citation for an incident's supporting source."""
    if not it.source_url:
        return ""
    tag = (
        "<span class='ctag news'>news</span>"
        if it.section == "news"
        else "<span class='ctag'>web</span>"
    )
    meta_bits = []
    if it.published:
        meta_bits.append(it.published[:10])
    if it.author:
        meta_bits.append(f"by {it.author}")
    meta = " &middot; ".join(meta_bits)
    snip = f"<div class='csnip'>&ldquo;{it.snippet}&rdquo;</div>" if it.snippet else ""
    return (
        f"<div class='cite'><img class='fav' src='{it.favicon}' alt=''/>"
        f"<div class='cb'>"
        f"<a href='{it.source_url}'><span class='cdom'>{it.domain or 'source'}</span></a>{tag}"
        f"<div class='cmeta'>{meta}</div>{snip}</div></div>"
    )

def _render_report(report: EnrichmentReport) -> str:
    events = sorted(report.events, key=lambda e: _SEVERITY_ORDER.get(e.severity, 4))
    flagged = sum(1 for e in events if e.severity in ("high", "medium"))
    total_sources = sum(len(e.incidents) for e in events)

    cards = []
    for e in events:
        incidents = "".join(
            f"<div class='inc'>&bull; {it.description}{_cite(it)}</div>"
            for it in e.incidents
        )
        cards.append(
            f"<div class='card {e.severity}'>"
            f"<div>{_sev_badge(e.severity)}"
            f"<span class='loc'><b>{e.event_id}</b> &middot; {e.location}</span></div>"
            f"<div class='summary'>{e.context_summary or 'No relevant public context found.'}</div>"
            f"{incidents}</div>"
        )

    return f"""
    {REPORT_CSS}
    <div class="rpt">
      <h1>Field-Data Enrichment</h1>
      <p class="sub">Geo-tagged events grounded in fresh public context — each
      incident cites a timestamped You.com Search result.</p>
      <div class="stats">
        <span class="pill"><b>{len(events)}</b> events</span>
        <span class="pill" style="background:#fdecea;color:#c0392b">
          <b>{flagged}</b> flagged (high/medium)</span>
        <span class="pill"><b>{total_sources}</b> cited You.com sources</span>
      </div>
      {''.join(cards) or "<p class='empty'>No events processed.</p>"}
      <p class="yoube">Public context retrieved via the You.com Search API with
      country + freshness targeting. Operational data never leaves the BYOC
      boundary — only public-web queries go out.</p>
    </div>
    """
# {{/docs-fragment report}}

# {{docs-fragment driver}}
DEFAULT_EVENTS = [
    GeoEvent("evt-1", "Mountain View, CA", "US", "road_closure_check"),
    GeoEvent("evt-2", "Tokyo, Japan", "JP", "weather"),
    GeoEvent("evt-3", "Austin, TX", "US", "road_closure_check"),
    GeoEvent("evt-4", "Phoenix, AZ", "US", "weather"),
    GeoEvent("evt-5", "London, UK", "GB", "incident"),
    GeoEvent("evt-6", "San Francisco, CA", "US", "incident"),
    GeoEvent("evt-7", "Seattle, WA", "US", "weather"),
    GeoEvent("evt-8", "Miami, FL", "US", "weather"),
    GeoEvent("evt-9", "Denver, CO", "US", "road_closure_check"),
    GeoEvent("evt-10", "Berlin, Germany", "DE", "incident"),
]

@env.task(report=True)
async def field_data_enrichment(
    events: list[GeoEvent] = DEFAULT_EVENTS,
    freshness: str = "day",
) -> EnrichmentReport:
    """Fan out across geo-tagged events and enrich each with public context."""
    with flyte.group("enrich-events"):
        enriched = await asyncio.gather(
            *[enrich_event(e, freshness) for e in events]
        )

    report = EnrichmentReport(events=list(enriched))
    await flyte.report.replace.aio(_render_report(report), do_flush=True)
    await flyte.report.flush.aio()
    return report
# {{/docs-fragment driver}}

# {{docs-fragment main}}
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(field_data_enrichment)
    print(run.url)
    run.wait()
# {{/docs-fragment main}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/field_data_enrichment_agent/main.py*
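
The `_parse_json` helper in the listing above slices from the first `{` or `[` to the last `}` or `]`. That handles Markdown code fences, but it can fail when the model adds text containing brackets after the JSON. The sketch below is a possible alternative, not part of the tutorial; the name `extract_json_object` is hypothetical. It uses the standard library's `json.JSONDecoder.raw_decode`, which stops at the end of the first complete JSON value. It only looks for objects because the prompt asks for a single JSON object:

```python
import json


def extract_json_object(text: str) -> dict:
    """Return the first complete JSON object embedded in an LLM reply.

    Hypothetical helper. Works with replies wrapped in Markdown code fences
    or surrounded by prose, because it tries to decode at each '{' in turn.
    """
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            # raw_decode parses one JSON value starting at `start` and
            # ignores whatever text follows it.
            value, _end = decoder.raw_decode(text, start)
            return value
        except json.JSONDecodeError:
            # Not valid JSON at this brace; try the next one.
            start = text.find("{", start + 1)
    raise ValueError("no JSON object found in model output")


fenced = '```json\n{"severity": "low", "incidents": []}\n```'
chatty = 'Here is the result: {"severity": "high"} Sources: [1], [2].'
print(extract_json_object(fenced))  # {'severity': 'low', 'incidents': []}
print(extract_json_object(chatty))  # {'severity': 'high'}
```

The slicing approach would fail on the second reply. Its last `]` comes from the trailing `[2]`, so the slice keeps extra text after the closing `}`.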

## Enrich one event

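The `enrich_event` task first builds a query from the event's location and type. It then calls the You.com Search API, passing the event's country and the requested freshness window. Finally, it asks Claude to summarize the relevant real-world context, extract discrete incidents, and assign an operational severity (`none`, `low`, `medium`, or `high`). The model cites each incident by the 1-based index of its supporting search result. The task maps that index back to the hit's URL, publication date, domain, and other metadata. If the index is out of range, the incident is kept without a citation.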

```
ENRICH_SYSTEM = """You are an operational-context analyst for autonomous and \
field systems. Given fresh local search results near a geofenced location, \
summarize the real-world context relevant to operations, extract discrete \
incidents (road closures, weather events, regulatory/airspace changes, local \
incidents), and assign an operational severity of 'none', 'low', 'medium', or \
'high'. Each incident must reference the supporting search result by its index. \
Respond ONLY with JSON:
{"context_summary": str, "severity": str, "incidents": [{"description": str, \
"source_index": int (the [n] of the supporting search result)}]}"""

# {{docs-fragment enrich_event}}
@env.task(retries=3)
async def enrich_event(event: GeoEvent, freshness: str) -> EnrichedEvent:
    """Ground one geo-tagged event in fresh public context."""
    query = f"{event.location} {event.event_type.replace('_', ' ')} road closure weather incident"
    hits = await you_search(query, country=event.country, freshness=freshness)

    evidence = "\n\n".join(
        f"[{i + 1}] {h.title} ({h.published}) — {h.domain}\n{h.url}\n{h.snippet}"
        for i, h in enumerate(hits)
    )
    user = (
        f"Location: {event.location}\n"
        f"Event type: {event.event_type}\n\n"
        f"Search results:\n{evidence or 'No results.'}"
    )
    parsed = await llm_json(ENRICH_SYSTEM, user)

    def _incident(it: dict) -> Incident:
        idx = int(it.get("source_index", 0) or 0)
        src = hits[idx - 1] if 1 <= idx <= len(hits) else None
        return Incident(
            description=str(it.get("description", "")),
            source_url=src.url if src else "",
            published=src.published if src else "",
            domain=src.domain if src else "",
            author=src.author if src else "",
            favicon=src.favicon if src else "",
            snippet=src.snippet if src else "",
            section=src.section if src else "web",
        )

    incidents = [_incident(it) for it in (parsed.get("incidents", []) or [])]
    return EnrichedEvent(
        event_id=event.event_id,
        location=event.location,
        context_summary=str(parsed.get("context_summary", "")),
        severity=str(parsed.get("severity", "none")),
        incidents=incidents,
    )
# {{/docs-fragment enrich_event}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/field_data_enrichment_agent/main.py*

## Orchestration

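The `field_data_enrichment` driver task runs one `enrich_event` call per event concurrently with `asyncio.gather`, grouped under `flyte.group("enrich-events")`. It then renders an HTML Flyte report with the events sorted by severity, highest first, and returns the collected results as an `EnrichmentReport`.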
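
Without Flyte, the fan-out and severity sort reduce to plain `asyncio`. In the minimal sketch below, `fake_enrich` is a hypothetical stand-in for `enrich_event`. The ranking dictionary mirrors `_SEVERITY_ORDER`, and unrecognized labels sort last:

```python
import asyncio

# Same ranking as _SEVERITY_ORDER in the tutorial; unknown labels get 4 and sort last.
SEVERITY_ORDER = {"high": 0, "medium": 1, "low": 2, "none": 3}


async def fake_enrich(event_id: str, severity: str) -> tuple[str, str]:
    """Hypothetical stand-in for enrich_event: no search or LLM call."""
    await asyncio.sleep(0)  # yield control, as a real I/O-bound call would
    return event_id, severity


async def enrich_all(events: list[tuple[str, str]]) -> list[tuple[str, str]]:
    # Fan out: all coroutines run concurrently; gather preserves input order.
    results = await asyncio.gather(*(fake_enrich(e, s) for e, s in events))
    # Sort for the report: high first, unrecognized severities at the end.
    return sorted(results, key=lambda r: SEVERITY_ORDER.get(r[1], 4))


events = [("evt-1", "low"), ("evt-2", "high"), ("evt-3", "unknown"), ("evt-4", "medium")]
print(asyncio.run(enrich_all(events)))
# [('evt-2', 'high'), ('evt-4', 'medium'), ('evt-1', 'low'), ('evt-3', 'unknown')]
```

`asyncio.gather` returns results in the order the coroutines were passed, not in completion order. The explicit sort is therefore what puts high-severity events at the top of the report. In the tutorial, each `enrich_event` call is itself a Flyte task, so the parallelism comes from Flyte running those tasks rather than from a single event loop.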

```
# {{docs-fragment report}}
_SEVERITY_ORDER = {"high": 0, "medium": 1, "low": 2, "none": 3}
_SEVERITY_STYLE = {
    "high": ("#fdecea", "#c0392b"),
    "medium": ("#fdf3e1", "#b7791f"),
    "low": ("#e3f1fb", "#2b6cb0"),
    "none": ("#eef1f4", "#627d98"),
}

REPORT_CSS = """
<style>
  .rpt { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto,
         Helvetica, Arial, sans-serif; color:#1f2933; max-width:1040px;
         margin:0 auto; }
  .rpt h1 { font-size:22px; margin:0 0 4px; color:#102a43; }
  .rpt .sub { color:#647488; font-size:13px; margin:0 0 18px; }
  .rpt .stats { display:flex; gap:10px; flex-wrap:wrap; margin:0 0 22px; }
  .rpt .pill { background:#f0f4f8; border-radius:999px; padding:6px 14px;
               font-size:13px; color:#334e68; }
  .rpt .pill b { color:#102a43; }
  .rpt .card { border:1px solid #e4e7eb; border-radius:12px; padding:16px 18px;
               margin:0 0 14px; box-shadow:0 1px 3px rgba(16,42,67,0.06);
               background:#fff; border-left:4px solid #cbd2d9; }
  .rpt .card.high { border-left-color:#c0392b; }
  .rpt .card.medium { border-left-color:#b7791f; }
  .rpt .card.low { border-left-color:#2b6cb0; }
  .rpt .card h2 { font-size:15px; margin:0 0 6px; color:#102a43; }
  .rpt .sev { display:inline-block; font-size:11px; font-weight:700;
              padding:3px 9px; border-radius:6px; text-transform:uppercase;
              letter-spacing:.03em; margin-right:8px; }
  .rpt .loc { font-size:13px; color:#52606d; }
  .rpt .summary { margin:8px 0; font-size:14px; line-height:1.45; }
  .rpt .inc { font-size:13px; color:#334e68; padding:6px 0; }
  .rpt .meta { color:#829ab1; font-size:12px; }
  .rpt a { color:#2b6cb0; text-decoration:none; }
  .rpt a:hover { text-decoration:underline; }
  .rpt .empty { color:#829ab1; font-style:italic; padding:8px 0; }
  .rpt .cite { display:flex; gap:9px; align-items:flex-start; background:#f7f9fb;
               border:1px solid #eef1f4; border-radius:8px; padding:7px 10px;
               margin:5px 0 2px 14px; }
  .rpt .cite img.fav { width:15px; height:15px; border-radius:3px; margin-top:2px;
                       flex:0 0 auto; background:#e4e7eb; }
  .rpt .cite .cb { font-size:12px; line-height:1.4; }
  .rpt .cite .cdom { font-weight:600; color:#334e68; }
  .rpt .cite .ctag { font-size:10px; font-weight:700; text-transform:uppercase;
                     color:#fff; background:#bcccdc; border-radius:4px;
                     padding:1px 5px; margin-left:6px; }
  .rpt .cite .ctag.news { background:#e8833a; }
  .rpt .cite .cmeta { color:#829ab1; }
  .rpt .cite .csnip { color:#52606d; font-style:italic; margin-top:2px; }
  .rpt .yoube { font-size:11px; color:#9aa5b1; margin-top:4px; }
</style>
"""

def _sev_badge(sev: str) -> str:
    bg, fg = _SEVERITY_STYLE.get(sev, ("#eef1f4", "#627d98"))
    return f"<span class='sev' style='background:{bg};color:{fg}'>{sev}</span>"

def _cite(it: Incident) -> str:
    """Render a rich You.com citation for an incident's supporting source."""
    if not it.source_url:
        return ""
    tag = (
        "<span class='ctag news'>news</span>"
        if it.section == "news"
        else "<span class='ctag'>web</span>"
    )
    meta_bits = []
    if it.published:
        meta_bits.append(it.published[:10])
    if it.author:
        meta_bits.append(f"by {it.author}")
    meta = " &middot; ".join(meta_bits)
    snip = f"<div class='csnip'>&ldquo;{it.snippet}&rdquo;</div>" if it.snippet else ""
    return (
        f"<div class='cite'><img class='fav' src='{it.favicon}' alt=''/>"
        f"<div class='cb'>"
        f"<a href='{it.source_url}'><span class='cdom'>{it.domain or 'source'}</span></a>{tag}"
        f"<div class='cmeta'>{meta}</div>{snip}</div></div>"
    )

def _render_report(report: EnrichmentReport) -> str:
    events = sorted(report.events, key=lambda e: _SEVERITY_ORDER.get(e.severity, 4))
    flagged = sum(1 for e in events if e.severity in ("high", "medium"))
    total_sources = sum(len(e.incidents) for e in events)

    cards = []
    for e in events:
        incidents = "".join(
            f"<div class='inc'>&bull; {it.description}{_cite(it)}</div>"
            for it in e.incidents
        )
        cards.append(
            f"<div class='card {e.severity}'>"
            f"<div>{_sev_badge(e.severity)}"
            f"<span class='loc'><b>{e.event_id}</b> &middot; {e.location}</span></div>"
            f"<div class='summary'>{e.context_summary or 'No relevant public context found.'}</div>"
            f"{incidents}</div>"
        )

    return f"""
    {REPORT_CSS}
    <div class="rpt">
      <h1>Field-Data Enrichment</h1>
      <p class="sub">Geo-tagged events grounded in fresh public context — each
      incident cites a timestamped You.com Search result.</p>
      <div class="stats">
        <span class="pill"><b>{len(events)}</b> events</span>
        <span class="pill" style="background:#fdecea;color:#c0392b">
          <b>{flagged}</b> flagged (high/medium)</span>
        <span class="pill"><b>{total_sources}</b> cited You.com sources</span>
      </div>
      {''.join(cards) or "<p class='empty'>No events processed.</p>"}
      <p class="yoube">Public context retrieved via the You.com Search API with
      country + freshness targeting. Operational data never leaves the BYOC
      boundary — only public-web queries go out.</p>
    </div>
    """
# {{/docs-fragment report}}

# {{docs-fragment driver}}
DEFAULT_EVENTS = [
    GeoEvent("evt-1", "Mountain View, CA", "US", "road_closure_check"),
    GeoEvent("evt-2", "Tokyo, Japan", "JP", "weather"),
    GeoEvent("evt-3", "Austin, TX", "US", "road_closure_check"),
    GeoEvent("evt-4", "Phoenix, AZ", "US", "weather"),
    GeoEvent("evt-5", "London, UK", "GB", "incident"),
    GeoEvent("evt-6", "San Francisco, CA", "US", "incident"),
    GeoEvent("evt-7", "Seattle, WA", "US", "weather"),
    GeoEvent("evt-8", "Miami, FL", "US", "weather"),
    GeoEvent("evt-9", "Denver, CO", "US", "road_closure_check"),
    GeoEvent("evt-10", "Berlin, Germany", "DE", "incident"),
]

@env.task(report=True)
async def field_data_enrichment(
    events: list[GeoEvent] = DEFAULT_EVENTS,
    freshness: str = "day",
) -> EnrichmentReport:
    """Fan out across geo-tagged events and enrich each with public context."""
    with flyte.group("enrich-events"):
        enriched = await asyncio.gather(
            *[enrich_event(e, freshness) for e in events]
        )

    report = EnrichmentReport(events=list(enriched))
    await flyte.report.replace.aio(_render_report(report), do_flush=True)
    await flyte.report.flush.aio()
    return report
# {{/docs-fragment driver}}

# {{docs-fragment main}}
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(field_data_enrichment)
    print(run.url)
    run.wait()
# {{/docs-fragment main}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/field_data_enrichment_agent/main.py*
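The report fragment above orders events with `_SEVERITY_ORDER` and counts high and medium events as flagged. The sketch below reproduces that logic on its own so it can be tested in isolation. The `Event` tuple is a hypothetical stand-in for `EnrichedEvent` that keeps only the fields used here. As in `_render_report`, unrecognized severity labels sort after `"none"`.

```python
from typing import NamedTuple

# Lower rank appears earlier in the report.
SEVERITY_ORDER = {"high": 0, "medium": 1, "low": 2, "none": 3}
UNKNOWN_RANK = len(SEVERITY_ORDER)  # unrecognized labels sort last (rank 4)


class Event(NamedTuple):
    """Hypothetical stand-in for EnrichedEvent with only the fields used here."""
    event_id: str
    severity: str


def order_and_flag(events: list[Event]) -> tuple[list[Event], int]:
    """Sort events by severity and count the high/medium ones."""
    ordered = sorted(events, key=lambda e: SEVERITY_ORDER.get(e.severity, UNKNOWN_RANK))
    flagged = sum(1 for e in ordered if e.severity in ("high", "medium"))
    return ordered, flagged


if __name__ == "__main__":
    sample = [
        Event("evt-1", "low"),
        Event("evt-2", "high"),
        Event("evt-3", "unclear"),
        Event("evt-4", "medium"),
        Event("evt-5", "none"),
    ]
    ordered, flagged = order_and_flag(sample)
    print([e.event_id for e in ordered], flagged)
```

Python's `sorted` is stable, so events that share a severity keep the order in which they were enriched.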

## Run the agent

### Create secrets

Get a You.com API key from the [You.com platform](https://you.com/platform) (see the [quickstart guide](https://you.com/docs/quickstart)). Get an Anthropic API key from the [Anthropic console](https://console.anthropic.com/).

Register both keys as Flyte secrets. The secret key names must match those declared in the `TaskEnvironment`:

```
flyte create secret youdotcom-api-key <YOUR_YOU_API_KEY>
flyte create secret internal-anthropic-api-key <YOUR_ANTHROPIC_API_KEY>
```

See [Secrets](https://www.union.ai/docs/v2/union/user-guide/task-configuration/secrets/page.md) for scoping and file-based secrets.

### Run locally or remotely

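From the root of a local clone of the [unionai-examples](https://github.com/unionai/unionai-examples) repository, change into the [example directory](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/field_data_enrichment_agent) and run the script: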

```
cd v2/tutorials/field_data_enrichment_agent
uv run --script main.py
```

To test locally without Flyte secrets, export the API keys as environment variables before running the script:

```
export YOU_API_KEY=<YOUR_YOU_API_KEY>
export ANTHROPIC_API_KEY=<YOUR_ANTHROPIC_API_KEY>

uv run --script main.py
```

The script prints the run URL. When the run completes, open that URL and view the report, which lists each enriched event with its operational severity and timestamped You.com source citations for its incidents.
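The report helpers insert web-sourced text (incident descriptions, snippets, authors, domains, URLs) into HTML f-strings without escaping. A snippet containing `<` or `&` can therefore be interpreted as markup. One minimal hardening step, which is not part of the tutorial code, is to pass such values through the standard library's `html.escape` before rendering them. `render_citation` below is a hypothetical, simplified variant of `_cite`:

```python
import html


def render_citation(url: str, domain: str, snippet: str) -> str:
    """Simplified, hypothetical variant of _cite that escapes untrusted text."""
    if not url:
        return ""
    # quote=True (the default) also escapes ' and ", which matters because the
    # attribute value below is wrapped in single quotes.
    safe_url = html.escape(url)
    safe_domain = html.escape(domain or "source")
    snip = f"<div class='csnip'>&ldquo;{html.escape(snippet)}&rdquo;</div>" if snippet else ""
    return (
        f"<div class='cite'><a href='{safe_url}'>"
        f"<span class='cdom'>{safe_domain}</span></a>{snip}</div>"
    )
```

Escaping does not validate URL schemes. Restricting `href` values to `http` and `https` would be a separate check. The same concern applies to the draft replies and snippets rendered by the support resolution agent's report.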

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/agents/support-resolution-agent ===

# Support resolution agent

> [!NOTE]
> Code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/support_resolution_agent).

This example demonstrates how to build a customer-support and field-service resolution agent on Flyte. The agent resolves tickets that need current public information — return policies, weather advisories, product recalls, manufacturer specs — and drafts a customer-ready reply with sources a human agent can verify before sending.

The [You.com Research API](https://you.com/docs/research/overview) grounds each ticket in fresh, citable sources. [Claude](https://docs.anthropic.com/) via [LiteLLM](https://docs.litellm.ai/) turns that research into a reply draft. With `research_effort="lite"`, the research step stays fast enough for human-in-the-loop support flows.
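Grounding a ticket comes down to turning a Research API response into an answer plus a list of citable sources. The sketch below mirrors the parsing done by the tutorial's `ground_answer` task on a plain dict. It uses trimmed-down copies of the tutorial's `Source` and `Grounding` dataclasses. The response shape (`output.content`, and `output.sources[*]` with `url`, `title`, and `snippets`) is inferred from the tutorial code, not taken from the You.com API reference. The domain helper uses `str.removeprefix` so that only a leading `www.` is stripped:

```python
import json
from dataclasses import dataclass, field
from urllib.parse import urlparse


@dataclass
class Source:
    title: str
    url: str
    snippet: str
    domain: str = ""


@dataclass
class Grounding:
    answer: str
    sources: list[Source] = field(default_factory=list)


def domain_of(url: str) -> str:
    # removeprefix strips only a leading "www.", leaving e.g. "awww.example.org" intact.
    return urlparse(url).netloc.removeprefix("www.")


def parse_research(result: dict) -> Grounding:
    """Flatten a Research-style response dict (shape assumed from the tutorial) into a Grounding."""
    output = result.get("output") or {}
    answer = output.get("content", "")
    if not isinstance(answer, str):
        answer = json.dumps(answer)  # structured content is kept as a JSON string
    sources = []
    for s in output.get("sources") or []:
        url = str(s.get("url", ""))
        snippets = s.get("snippets") or [""]
        sources.append(
            Source(
                title=str(s.get("title") or url),  # fall back to the URL when untitled
                url=url,
                snippet=str(snippets[0]),  # keep only the first snippet
                domain=domain_of(url),
            )
        )
    return Grounding(answer=answer, sources=sources)
```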

Flyte provides:

- **Fan-out parallelism** across support tickets
- **`@flyte.trace`** on every external call for lineage
- A **two-step pipeline** per ticket: ground the answer, then draft the reply
- **Flyte reports** with draft replies and verifiable source citations
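The per-ticket flow (ground, then draft) and the fan-out across tickets can be sketched with plain `asyncio`. Here, `ground` and `draft` are hypothetical stubs standing in for the `ground_answer` and `draft_reply` tasks. The `asyncio.Semaphore` cap is an addition that the tutorial does not use; it shows one way to limit concurrent calls to a rate-limited API. In the tutorial, the steps are Flyte tasks, so Flyte schedules their execution; the sketch only reproduces the control flow.

```python
import asyncio


async def ground(question: str) -> str:
    """Stub for the grounding step (a Research API call in the tutorial)."""
    await asyncio.sleep(0.01)
    return f"answer to: {question}"


async def draft(question: str, answer: str) -> str:
    """Stub for the drafting step (an LLM call in the tutorial)."""
    await asyncio.sleep(0.01)
    return f"Reply ({question}): {answer}"


async def resolve(question: str, limit: asyncio.Semaphore) -> str:
    # Steps run sequentially per ticket because drafting needs the grounding.
    # The semaphore bounds only the rate-limited research call.
    async with limit:
        answer = await ground(question)
    return await draft(question, answer)


async def resolve_all(questions: list[str], max_concurrency: int = 4) -> list[str]:
    limit = asyncio.Semaphore(max_concurrency)
    # Tickets run concurrently; gather returns results in input order.
    return await asyncio.gather(*(resolve(q, limit) for q in questions))


if __name__ == "__main__":
    print(asyncio.run(resolve_all(["q1", "q2", "q3"])))
```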

![Support resolution agent report](https://www.union.ai/docs/v2/union/_static/images/tutorials/support_resolution_agent/support-resolutions-agent.png)

## Setting up the environment

The agent runs in a `TaskEnvironment` with secrets for the You.com and Anthropic API keys and a container image built from the `uv` script dependencies.

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#     "flyte>=2.4.0",
#     "httpx>=0.27.0",
#     "litellm>=1.72.0",
# ]
# main = "support_resolution"
# params = ""
# ///
"""Customer-support & field-service resolution agent.

Grounds a support ticket in fresh, public, citable sources via the You.com
Research API (low effort for low latency, human-in-the-loop use), then uses
Claude to draft a customer-ready reply that cites its sources inline so a human
agent can verify before sending.
"""

# {{docs-fragment env}}
import asyncio
import json
import os
from dataclasses import dataclass, field

import flyte

MODEL = "anthropic/claude-haiku-4-5"

env = flyte.TaskEnvironment(
    name="support-resolution",
    secrets=[
        flyte.Secret(key="youdotcom-api-key", as_env_var="YOU_API_KEY"),
        flyte.Secret(key="internal-anthropic-api-key", as_env_var="ANTHROPIC_API_KEY"),
    ],
    image=flyte.Image.from_uv_script(__file__, name="support-resolution", pre=True),
    resources=flyte.Resources(cpu="1", memory="1Gi"),
)
# {{/docs-fragment env}}

# {{docs-fragment data_types}}
@dataclass
class Source:
    title: str
    url: str
    snippet: str
    domain: str = ""
    favicon: str = ""

def _domain(url: str) -> str:
    from urllib.parse import urlparse

    try:
        return urlparse(url).netloc.removeprefix("www.")
    except Exception:
        return ""

def _favicon_for(url: str) -> str:
    return f"https://ydc-index.io/favicon?domain={_domain(url)}&size=128"

@dataclass
class Ticket:
    ticket_id: str
    question: str
    context: str = ""

@dataclass
class Grounding:
    answer: str
    sources: list[Source] = field(default_factory=list)

@dataclass
class Resolution:
    ticket_id: str
    ticket: str
    grounded_answer: str
    draft_reply: str
    sources: list[Source] = field(default_factory=list)

@dataclass
class ResolutionReport:
    resolutions: list[Resolution] = field(default_factory=list)
# {{/docs-fragment data_types}}

# {{docs-fragment you_research}}
YOU_RESEARCH_URL = "https://api.you.com/v1/research"

async def _you_post(url: str, body: dict, timeout: float = 120.0) -> dict:
    """POST with exponential backoff + jitter on 429 rate limits."""
    import random

    import httpx

    headers = {
        "X-API-Key": os.environ["YOU_API_KEY"],
        "Content-Type": "application/json",
    }
    async with httpx.AsyncClient(timeout=timeout) as client:
        for attempt in range(7):
            resp = await client.post(url, headers=headers, json=body)
            if resp.status_code == 429 and attempt < 6:
                wait = float(resp.headers.get("retry-after") or 0) or min(2**attempt, 30)
                await asyncio.sleep(wait + random.uniform(0, 2))
                continue
            resp.raise_for_status()
            return resp.json()
    resp.raise_for_status()
    return resp.json()

@flyte.trace
async def you_research(question: str, research_effort: str = "lite") -> dict:
    """Fast, citation-backed grounding for a support question."""
    body = {"input": question, "research_effort": research_effort}
    return await _you_post(YOU_RESEARCH_URL, body)
# {{/docs-fragment you_research}}

# {{docs-fragment ground_answer}}
@env.task(retries=3)
async def ground_answer(ticket: str, context: str, research_effort: str) -> Grounding:
    """Ground the ticket in fresh public sources via the Research API."""
    question = ticket if not context else f"{ticket}\n\nContext: {context}"
    result = await you_research(question, research_effort)

    output = result.get("output", {})
    answer = output.get("content", "")
    if not isinstance(answer, str):
        answer = json.dumps(answer)

    sources = []
    for s in output.get("sources", []) or []:
        url = str(s.get("url", ""))
        sources.append(
            Source(
                title=str(s.get("title", "") or url),
                url=url,
                snippet=str((s.get("snippets") or [""])[0]),
                domain=_domain(url),
                favicon=_favicon_for(url),
            )
        )
    return Grounding(answer=answer, sources=sources)
# {{/docs-fragment ground_answer}}

# {{docs-fragment draft_reply}}
@flyte.trace
async def _draft(ticket: str, answer: str, sources_text: str) -> str:
    from litellm import acompletion

    system = (
        "You are a senior customer-support agent. Using ONLY the grounded "
        "answer and sources provided, draft a concise, friendly, customer-ready "
        "reply. Cite the relevant source URL inline in parentheses after any "
        "factual claim so a human agent can verify before sending. If the "
        "sources do not answer the question, say so plainly."
    )
    user = (
        f"Customer ticket: {ticket}\n\n"
        f"Grounded answer:\n{answer}\n\nSources:\n{sources_text}"
    )
    resp = await acompletion(
        model=MODEL,
        messages=[
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ],
        temperature=0.2,
        max_tokens=1024,
    )
    return resp.choices[0].message.content

@env.task
async def draft_reply(ticket: Ticket, grounding: Grounding) -> Resolution:
    """Turn the grounded answer into a cited, customer-ready reply."""
    sources_text = "\n".join(
        f"- {s.title} ({s.domain}): {s.url}\n  \"{s.snippet}\""
        for s in grounding.sources
    )
    reply = await _draft(ticket.question, grounding.answer, sources_text)

    return Resolution(
        ticket_id=ticket.ticket_id,
        ticket=ticket.question,
        grounded_answer=grounding.answer,
        draft_reply=reply,
        sources=grounding.sources,
    )
# {{/docs-fragment draft_reply}}

# {{docs-fragment resolve_ticket}}
async def resolve_ticket(ticket: Ticket, research_effort: str) -> Resolution:
    """Ground one ticket then draft its reply."""
    grounding = await ground_answer(ticket.question, ticket.context, research_effort)
    return await draft_reply(ticket, grounding)
# {{/docs-fragment resolve_ticket}}

# {{docs-fragment report}}
REPORT_CSS = """
<style>
  .rpt { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto,
         Helvetica, Arial, sans-serif; color:#1f2933; max-width:1040px;
         margin:0 auto; }
  .rpt h1 { font-size:22px; margin:0 0 4px; color:#102a43; }
  .rpt .sub { color:#647488; font-size:13px; margin:0 0 18px; }
  .rpt .stats { display:flex; gap:10px; flex-wrap:wrap; margin:0 0 22px; }
  .rpt .pill { background:#f0f4f8; border-radius:999px; padding:6px 14px;
               font-size:13px; color:#334e68; }
  .rpt .pill b { color:#102a43; }
  .rpt .card { border:1px solid #e4e7eb; border-radius:12px; padding:18px 20px;
               margin:0 0 16px; box-shadow:0 1px 3px rgba(16,42,67,0.06);
               background:#fff; }
  .rpt .tid { display:inline-block; font-size:11px; font-weight:700;
              padding:3px 9px; border-radius:6px; background:#e0e8f9;
              color:#2b4ba0; margin-right:8px; }
  .rpt .q { font-size:15px; font-weight:600; color:#102a43; margin:8px 0 12px; }
  .rpt .reply { background:#f7faf7; border:1px solid #e1ece1; border-radius:8px;
                padding:12px 14px; font-size:14px; line-height:1.55; }
  .rpt .reply h3 { font-size:11px; text-transform:uppercase; letter-spacing:.04em;
                   color:#3c8a5e; margin:0 0 8px; }
  .rpt .sources { margin-top:12px; }
  .rpt .sources h3 { font-size:11px; text-transform:uppercase; color:#627d98;
                     margin:0 0 8px; }
  .rpt a { color:#2b6cb0; text-decoration:none; }
  .rpt a:hover { text-decoration:underline; }
  .rpt .empty { color:#829ab1; font-style:italic; padding:8px 0; }
  .rpt .cite { display:flex; gap:9px; align-items:flex-start; background:#f7f9fb;
               border:1px solid #eef1f4; border-radius:8px; padding:7px 10px;
               margin:0 0 6px; }
  .rpt .cite img.fav { width:15px; height:15px; border-radius:3px; margin-top:2px;
                       flex:0 0 auto; background:#e4e7eb; }
  .rpt .cite .cb { font-size:12px; line-height:1.4; }
  .rpt .cite .cdom { font-weight:600; color:#334e68; }
  .rpt .cite .ctag { font-size:10px; font-weight:700; text-transform:uppercase;
                     color:#fff; background:#5b8def; border-radius:4px;
                     padding:1px 5px; margin-left:6px; }
  .rpt .cite .cmeta { color:#829ab1; }
  .rpt .cite .csnip { color:#52606d; font-style:italic; margin-top:2px; }
  .rpt .yoube { font-size:11px; color:#9aa5b1; margin-top:4px; }
</style>
"""

def _cite(s: Source) -> str:
    """Render a rich You.com Research citation for a support source."""
    if not s.url:
        return ""
    snip = f"<div class='csnip'>&ldquo;{s.snippet}&rdquo;</div>" if s.snippet else ""
    return (
        f"<div class='cite'><img class='fav' src='{s.favicon}' alt=''/>"
        f"<div class='cb'>"
        f"<a href='{s.url}'><span class='cdom'>{s.domain or 'source'}</span></a>"
        f"<span class='ctag'>research</span>"
        f"<div class='cmeta'>{s.title}</div>{snip}</div></div>"
    )

def _render_report(report: ResolutionReport) -> str:
    cards = []
    for res in report.resolutions:
        src = "".join(_cite(s) for s in res.sources[:8])
        reply_html = res.draft_reply.replace("\n", "<br/>")
        cards.append(
            f"<div class='card'>"
            f"<div><span class='tid'>{res.ticket_id}</span></div>"
            f"<div class='q'>{res.ticket}</div>"
            f"<div class='reply'><h3>Draft reply (for human review)</h3>{reply_html}</div>"
            + (f"<div class='sources'><h3>You.com sources ({len(res.sources)})</h3>{src}</div>" if src else "")
            + "</div>"
        )

    total_sources = sum(len(r.sources) for r in report.resolutions)
    return f"""
    {REPORT_CSS}
    <div class="rpt">
      <h1>Support Resolutions</h1>
      <p class="sub">Tickets grounded in fresh public sources via the You.com
      Research API — draft replies cite sources a human agent can verify.</p>
      <div class="stats">
        <span class="pill"><b>{len(report.resolutions)}</b> tickets</span>
        <span class="pill"><b>{total_sources}</b> You.com sources cited</span>
      </div>
      {''.join(cards) or "<p class='empty'>No tickets processed.</p>"}
      <p class="yoube">Each ticket grounded by the You.com Research API
      (<code>lite</code> effort for low-latency, human-in-the-loop use). Sources
      include domain, title, and snippet provenance — ready to paste into a
      customer reply with verification links.</p>
    </div>
    """
# {{/docs-fragment report}}

# {{docs-fragment driver}}
def _default_tickets() -> list[Ticket]:
    return [
        Ticket(
            "tkt-1",
            "Is there a recall on the DeWalt DCD777 cordless drill, and what should "
            "the customer do if there is?",
            "Customer purchased the drill recently and is asking about safety recalls.",
        ),
        Ticket(
            "tkt-2",
            "What is Sony's current return policy for the WH-1000XM5 headphones?",
            "Customer wants to return an opened pair bought 20 days ago.",
        ),
        Ticket(
            "tkt-3",
            "Are there any current weather advisories that could delay flights out of "
            "Denver International Airport today?",
            "Customer is worried about a connecting flight.",
        ),
        Ticket(
            "tkt-4",
            "What are the dimensions and weight capacity of the IKEA BEKANT desk?",
            "Customer is checking if it fits their space before resolving a complaint.",
        ),
        Ticket(
            "tkt-5",
            "Has Samsung issued any recall or safety notice for the Galaxy Z Fold5?",
            "Customer reports overheating and wants to know about known issues.",
        ),
        Ticket(
            "tkt-6",
            "What is the warranty period for a Dyson V15 Detect vacuum in the US?",
            "Customer's vacuum stopped working and asks about coverage.",
        ),
    ]

@env.task(report=True)
async def support_resolution(
    tickets: list[Ticket] | None = None,
    research_effort: str = "lite",
) -> ResolutionReport:
    """Fan out across support tickets, grounding and drafting cited replies."""
    if tickets is None:
        tickets = _default_tickets()

    with flyte.group("resolve-tickets"):
        resolutions = await asyncio.gather(
            *[resolve_ticket(t, research_effort) for t in tickets]
        )

    report = ResolutionReport(resolutions=list(resolutions))
    await flyte.report.replace.aio(_render_report(report), do_flush=True)
    await flyte.report.flush.aio()
    return report
# {{/docs-fragment driver}}

# {{docs-fragment main}}
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(support_resolution)
    print(run.url)
    run.wait()
# {{/docs-fragment main}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/support_resolution_agent/main.py*

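The Python dependencies are declared at the top of the file as `uv`-style inline script metadata (the PEP 723 format), and `flyte.Image.from_uv_script` builds the task image from them: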

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#     "flyte>=2.4.0",
#     "httpx>=0.27.0",
#     "litellm>=1.72.0",
# ]
# ///
```

## Data types

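Each `Ticket` carries a ticket ID, a customer question, and optional context about the customer's situation. `ground_answer` returns a `Grounding` (the research answer plus a list of `Source` records). The final `Resolution` combines the ticket ID and question with the grounded answer, the draft reply, and the You.com sources.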

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#     "flyte>=2.4.0",
#     "httpx>=0.27.0",
#     "litellm>=1.72.0",
# ]
# main = "support_resolution"
# params = ""
# ///
"""Customer-support & field-service resolution agent.

Grounds a support ticket in fresh, public, citable sources via the You.com
Research API (low effort for low latency, human-in-the-loop use), then uses
Claude to draft a customer-ready reply that cites its sources inline so a human
agent can verify before sending.
"""

# {{docs-fragment env}}
import asyncio
import json
import os
from dataclasses import dataclass, field

import flyte

MODEL = "anthropic/claude-haiku-4-5"

env = flyte.TaskEnvironment(
    name="support-resolution",
    secrets=[
        flyte.Secret(key="youdotcom-api-key", as_env_var="YOU_API_KEY"),
        flyte.Secret(key="internal-anthropic-api-key", as_env_var="ANTHROPIC_API_KEY"),
    ],
    image=flyte.Image.from_uv_script(__file__, name="support-resolution", pre=True),
    resources=flyte.Resources(cpu="1", memory="1Gi"),
)
# {{/docs-fragment env}}

# {{docs-fragment data_types}}
@dataclass
class Source:
    title: str
    url: str
    snippet: str
    domain: str = ""
    favicon: str = ""

def _domain(url: str) -> str:
    from urllib.parse import urlparse

    try:
        return urlparse(url).netloc.removeprefix("www.")
    except Exception:
        return ""

def _favicon_for(url: str) -> str:
    return f"https://ydc-index.io/favicon?domain={_domain(url)}&size=128"

@dataclass
class Ticket:
    ticket_id: str
    question: str
    context: str = ""

@dataclass
class Grounding:
    answer: str
    sources: list[Source] = field(default_factory=list)

@dataclass
class Resolution:
    ticket_id: str
    ticket: str
    grounded_answer: str
    draft_reply: str
    sources: list[Source] = field(default_factory=list)

@dataclass
class ResolutionReport:
    resolutions: list[Resolution] = field(default_factory=list)
# {{/docs-fragment data_types}}

# {{docs-fragment you_research}}
YOU_RESEARCH_URL = "https://api.you.com/v1/research"

async def _you_post(url: str, body: dict, timeout: float = 120.0) -> dict:
    """POST with exponential backoff + jitter on 429 rate limits."""
    import random

    import httpx

    headers = {
        "X-API-Key": os.environ["YOU_API_KEY"],
        "Content-Type": "application/json",
    }
    async with httpx.AsyncClient(timeout=timeout) as client:
        for attempt in range(7):
            resp = await client.post(url, headers=headers, json=body)
            if resp.status_code == 429 and attempt < 6:
                wait = float(resp.headers.get("retry-after") or 0) or min(2**attempt, 30)
                await asyncio.sleep(wait + random.uniform(0, 2))
                continue
            resp.raise_for_status()
            return resp.json()
    resp.raise_for_status()
    return resp.json()

@flyte.trace
async def you_research(question: str, research_effort: str = "lite") -> dict:
    """Fast, citation-backed grounding for a support question."""
    body = {"input": question, "research_effort": research_effort}
    return await _you_post(YOU_RESEARCH_URL, body)
# {{/docs-fragment you_research}}

# {{docs-fragment ground_answer}}
@env.task(retries=3)
async def ground_answer(ticket: str, context: str, research_effort: str) -> Grounding:
    """Ground the ticket in fresh public sources via the Research API."""
    question = ticket if not context else f"{ticket}\n\nContext: {context}"
    result = await you_research(question, research_effort)

    output = result.get("output", {})
    answer = output.get("content", "")
    if not isinstance(answer, str):
        answer = json.dumps(answer)

    sources = []
    for s in output.get("sources", []) or []:
        url = str(s.get("url", ""))
        sources.append(
            Source(
                title=str(s.get("title", "") or url),
                url=url,
                snippet=str((s.get("snippets") or [""])[0]),
                domain=_domain(url),
                favicon=_favicon_for(url),
            )
        )
    return Grounding(answer=answer, sources=sources)
# {{/docs-fragment ground_answer}}

# {{docs-fragment draft_reply}}
@flyte.trace
async def _draft(ticket: str, answer: str, sources_text: str) -> str:
    from litellm import acompletion

    system = (
        "You are a senior customer-support agent. Using ONLY the grounded "
        "answer and sources provided, draft a concise, friendly, customer-ready "
        "reply. Cite the relevant source URL inline in parentheses after any "
        "factual claim so a human agent can verify before sending. If the "
        "sources do not answer the question, say so plainly."
    )
    user = (
        f"Customer ticket: {ticket}\n\n"
        f"Grounded answer:\n{answer}\n\nSources:\n{sources_text}"
    )
    resp = await acompletion(
        model=MODEL,
        messages=[
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ],
        temperature=0.2,
        max_tokens=1024,
    )
    return resp.choices[0].message.content

@env.task
async def draft_reply(ticket: Ticket, grounding: Grounding) -> Resolution:
    """Turn the grounded answer into a cited, customer-ready reply."""
    sources_text = "\n".join(
        f"- {s.title} ({s.domain}): {s.url}\n  \"{s.snippet}\""
        for s in grounding.sources
    )
    reply = await _draft(ticket.question, grounding.answer, sources_text)

    return Resolution(
        ticket_id=ticket.ticket_id,
        ticket=ticket.question,
        grounded_answer=grounding.answer,
        draft_reply=reply,
        sources=grounding.sources,
    )
# {{/docs-fragment draft_reply}}

# {{docs-fragment resolve_ticket}}
async def resolve_ticket(ticket: Ticket, research_effort: str) -> Resolution:
    """Ground one ticket then draft its reply."""
    grounding = await ground_answer(ticket.question, ticket.context, research_effort)
    return await draft_reply(ticket, grounding)
# {{/docs-fragment resolve_ticket}}

# {{docs-fragment report}}
REPORT_CSS = """
<style>
  .rpt { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto,
         Helvetica, Arial, sans-serif; color:#1f2933; max-width:1040px;
         margin:0 auto; }
  .rpt h1 { font-size:22px; margin:0 0 4px; color:#102a43; }
  .rpt .sub { color:#647488; font-size:13px; margin:0 0 18px; }
  .rpt .stats { display:flex; gap:10px; flex-wrap:wrap; margin:0 0 22px; }
  .rpt .pill { background:#f0f4f8; border-radius:999px; padding:6px 14px;
               font-size:13px; color:#334e68; }
  .rpt .pill b { color:#102a43; }
  .rpt .card { border:1px solid #e4e7eb; border-radius:12px; padding:18px 20px;
               margin:0 0 16px; box-shadow:0 1px 3px rgba(16,42,67,0.06);
               background:#fff; }
  .rpt .tid { display:inline-block; font-size:11px; font-weight:700;
              padding:3px 9px; border-radius:6px; background:#e0e8f9;
              color:#2b4ba0; margin-right:8px; }
  .rpt .q { font-size:15px; font-weight:600; color:#102a43; margin:8px 0 12px; }
  .rpt .reply { background:#f7faf7; border:1px solid #e1ece1; border-radius:8px;
                padding:12px 14px; font-size:14px; line-height:1.55; }
  .rpt .reply h3 { font-size:11px; text-transform:uppercase; letter-spacing:.04em;
                   color:#3c8a5e; margin:0 0 8px; }
  .rpt .sources { margin-top:12px; }
  .rpt .sources h3 { font-size:11px; text-transform:uppercase; color:#627d98;
                     margin:0 0 8px; }
  .rpt a { color:#2b6cb0; text-decoration:none; }
  .rpt a:hover { text-decoration:underline; }
  .rpt .empty { color:#829ab1; font-style:italic; padding:8px 0; }
  .rpt .cite { display:flex; gap:9px; align-items:flex-start; background:#f7f9fb;
               border:1px solid #eef1f4; border-radius:8px; padding:7px 10px;
               margin:0 0 6px; }
  .rpt .cite img.fav { width:15px; height:15px; border-radius:3px; margin-top:2px;
                       flex:0 0 auto; background:#e4e7eb; }
  .rpt .cite .cb { font-size:12px; line-height:1.4; }
  .rpt .cite .cdom { font-weight:600; color:#334e68; }
  .rpt .cite .ctag { font-size:10px; font-weight:700; text-transform:uppercase;
                     color:#fff; background:#5b8def; border-radius:4px;
                     padding:1px 5px; margin-left:6px; }
  .rpt .cite .cmeta { color:#829ab1; }
  .rpt .cite .csnip { color:#52606d; font-style:italic; margin-top:2px; }
  .rpt .yoube { font-size:11px; color:#9aa5b1; margin-top:4px; }
</style>
"""

def _cite(s: Source) -> str:
    """Render a rich You.com Research citation for a support source."""
    if not s.url:
        return ""
    snip = f"<div class='csnip'>&ldquo;{s.snippet}&rdquo;</div>" if s.snippet else ""
    return (
        f"<div class='cite'><img class='fav' src='{s.favicon}' alt=''/>"
        f"<div class='cb'>"
        f"<a href='{s.url}'><span class='cdom'>{s.domain or 'source'}</span></a>"
        f"<span class='ctag'>research</span>"
        f"<div class='cmeta'>{s.title}</div>{snip}</div></div>"
    )

def _render_report(report: ResolutionReport) -> str:
    cards = []
    for res in report.resolutions:
        src = "".join(_cite(s) for s in res.sources[:8])
        reply_html = res.draft_reply.replace("\n", "<br/>")
        cards.append(
            f"<div class='card'>"
            f"<div><span class='tid'>{res.ticket_id}</span></div>"
            f"<div class='q'>{res.ticket}</div>"
            f"<div class='reply'><h3>Draft reply (for human review)</h3>{reply_html}</div>"
            + (f"<div class='sources'><h3>You.com sources ({len(res.sources)})</h3>{src}</div>" if src else "")
            + "</div>"
        )

    total_sources = sum(len(r.sources) for r in report.resolutions)
    return f"""
    {REPORT_CSS}
    <div class="rpt">
      <h1>Support Resolutions</h1>
      <p class="sub">Tickets grounded in fresh public sources via the You.com
      Research API — draft replies cite sources a human agent can verify.</p>
      <div class="stats">
        <span class="pill"><b>{len(report.resolutions)}</b> tickets</span>
        <span class="pill"><b>{total_sources}</b> You.com sources cited</span>
      </div>
      {''.join(cards) or "<p class='empty'>No tickets processed.</p>"}
      <p class="yoube">Each ticket grounded by the You.com Research API
      (<code>lite</code> effort for low-latency, human-in-the-loop use). Sources
      include domain, title, and snippet provenance — ready to paste into a
      customer reply with verification links.</p>
    </div>
    """
# {{/docs-fragment report}}

# {{docs-fragment driver}}
def _default_tickets() -> list[Ticket]:
    return [
        Ticket(
            "tkt-1",
            "Is there a recall on the DeWalt DCD777 cordless drill, and what should "
            "the customer do if there is?",
            "Customer purchased the drill recently and is asking about safety recalls.",
        ),
        Ticket(
            "tkt-2",
            "What is Sony's current return policy for the WH-1000XM5 headphones?",
            "Customer wants to return an opened pair bought 20 days ago.",
        ),
        Ticket(
            "tkt-3",
            "Are there any current weather advisories that could delay flights out of "
            "Denver International Airport today?",
            "Customer is worried about a connecting flight.",
        ),
        Ticket(
            "tkt-4",
            "What are the dimensions and weight capacity of the IKEA BEKANT desk?",
            "Customer is checking if it fits their space before resolving a complaint.",
        ),
        Ticket(
            "tkt-5",
            "Has Samsung issued any recall or safety notice for the Galaxy Z Fold5?",
            "Customer reports overheating and wants to know about known issues.",
        ),
        Ticket(
            "tkt-6",
            "What is the warranty period for a Dyson V15 Detect vacuum in the US?",
            "Customer's vacuum stopped working and asks about coverage.",
        ),
    ]

@env.task(report=True)
async def support_resolution(
    tickets: list[Ticket] | None = None,
    research_effort: str = "lite",
) -> ResolutionReport:
    """Fan out across support tickets, grounding and drafting cited replies."""
    if tickets is None:
        tickets = _default_tickets()

    with flyte.group("resolve-tickets"):
        resolutions = await asyncio.gather(
            *[resolve_ticket(t, research_effort) for t in tickets]
        )

    report = ResolutionReport(resolutions=list(resolutions))
    await flyte.report.replace.aio(_render_report(report), do_flush=True)
    await flyte.report.flush.aio()
    return report
# {{/docs-fragment driver}}

# {{docs-fragment main}}
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(support_resolution)
    print(run.url)
    run.wait()
# {{/docs-fragment main}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/support_resolution_agent/main.py*
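
The `support_resolution` driver above fans out with `asyncio.gather`. Each `resolve_ticket` call awaits `ground_answer` and then `draft_reply`, and the calls for all tickets run concurrently. The plain-asyncio sketch below uses no Flyte, and `resolve` is a hypothetical stand-in for `resolve_ticket` with random simulated latency. It illustrates the ordering guarantee the report depends on: `gather` returns results in the order the coroutines were passed, not the order they finish, so report cards line up with the input tickets.

```python
import asyncio
import random


async def resolve(ticket_id: str) -> str:
    """Hypothetical stand-in for resolve_ticket: random simulated latency, then a result."""
    await asyncio.sleep(random.uniform(0, 0.05))
    return f"resolved {ticket_id}"


async def fan_out(ticket_ids: list[str]) -> list[str]:
    # gather runs every coroutine concurrently and returns results in argument
    # order, regardless of which one finishes first.
    return list(await asyncio.gather(*[resolve(t) for t in ticket_ids]))


print(asyncio.run(fan_out(["tkt-1", "tkt-2", "tkt-3"])))
# ['resolved tkt-1', 'resolved tkt-2', 'resolved tkt-3']
```

The driver calls `gather` without `return_exceptions=True`. As a result, the first ticket that still fails after `ground_answer`'s retries raises out of `support_resolution`. Passing `return_exceptions=True` would instead return exception objects alongside successful resolutions, and those would then have to be handled before the report is built.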

## Ground answers with the You.com Research API

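The `you_research` helper sends a question to the [You.com Research API](https://you.com/docs/research/overview) with a configurable `research_effort`. For support tickets, `lite` returns a fast, citation-backed answer suited to real-time, human-in-the-loop flows. See the [Research API reference](https://you.com/docs/api-reference/research/v1-research) for the other effort levels and parameters. The helper is decorated with `@flyte.trace`, so its inputs and outputs are recorded as part of the calling task. The underlying `_you_post` function retries HTTP 429 (rate-limit) responses up to six times. When the response carries a numeric `Retry-After` value, it waits that long. Otherwise it backs off exponentially (1, 2, 4, 8, 16, then 30 seconds). Up to 2 seconds of random jitter are added to each wait.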

```
YOU_RESEARCH_URL = "https://api.you.com/v1/research"

async def _you_post(url: str, body: dict, timeout: float = 120.0) -> dict:
    """POST with exponential backoff + jitter on 429 rate limits."""
    import random

    import httpx

    headers = {
        "X-API-Key": os.environ["YOU_API_KEY"],
        "Content-Type": "application/json",
    }
    async with httpx.AsyncClient(timeout=timeout) as client:
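        for attempt in range(7):
            resp = await client.post(url, headers=headers, json=body)
            if resp.status_code == 429 and attempt < 6:
                # Prefer a numeric Retry-After. The header may also be an HTTP date,
                # which float() rejects, so fall back to exponential backoff capped at 30 s.
                try:
                    wait = float(resp.headers.get("retry-after") or 0)
                except ValueError:
                    wait = 0.0
                wait = wait or min(2**attempt, 30)
                await asyncio.sleep(wait + random.uniform(0, 2))
                continue
            # Any other status, or a 429 on the final attempt: raise on errors, else return.
            resp.raise_for_status()
            return resp.json()
    # Unreachable: the last loop iteration always returns or raises.
    resp.raise_for_status()
    return resp.json()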

@flyte.trace
async def you_research(question: str, research_effort: str = "lite") -> dict:
    """Fast, citation-backed grounding for a support question."""
    body = {"input": question, "research_effort": research_effort}
    return await _you_post(YOU_RESEARCH_URL, body)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/support_resolution_agent/main.py*
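
To see the retry policy in isolation, here is a minimal, network-free sketch of the 429 backoff loop in `_you_post`. The names `backoff_delay` and `post_with_backoff`, the `(status, headers, body)` response tuple, and the injectable `sleep` are hypothetical. They exist only so the loop can run without httpx or an API key. Headers here are a plain, case-sensitive dict, unlike httpx's headers. One deliberate difference from the tutorial code: a `Retry-After` value that is not a number (the header may also carry an HTTP date) falls back to exponential backoff instead of raising.

```python
from __future__ import annotations

import asyncio
import random
from typing import Awaitable, Callable

# Hypothetical response shape for this sketch: (status_code, headers, json_body).
Response = tuple[int, dict[str, str], dict]


def backoff_delay(attempt: int, retry_after: str | None, cap: float = 30.0) -> float:
    """Seconds to wait after a 429: a numeric Retry-After wins, else 2**attempt capped at `cap`."""
    try:
        hinted = float(retry_after or 0)
    except ValueError:  # e.g. an HTTP-date Retry-After; fall back to exponential backoff
        hinted = 0.0
    return hinted or min(2**attempt, cap)


async def post_with_backoff(
    send: Callable[[], Awaitable[Response]],
    max_attempts: int = 7,
    jitter: float = 2.0,
    sleep: Callable[[float], Awaitable[None]] = asyncio.sleep,
) -> dict:
    """Call `send` until it returns a non-429 status, sleeping between 429s."""
    for attempt in range(max_attempts):
        status, headers, body = await send()
        if status == 429 and attempt < max_attempts - 1:
            delay = backoff_delay(attempt, headers.get("retry-after"))
            await sleep(delay + random.uniform(0, jitter))
            continue
        if status >= 400:  # stands in for httpx's raise_for_status()
            raise RuntimeError(f"HTTP {status}")
        return body
    raise RuntimeError("max_attempts must be at least 1")


async def demo() -> None:
    # Two rate-limited replies, then success; the second 429 carries Retry-After: 3.
    replies = iter([(429, {}, {}), (429, {"retry-after": "3"}, {}), (200, {}, {"output": {}})])
    waits: list[float] = []

    async def fake_send() -> Response:
        return next(replies)

    async def record_sleep(seconds: float) -> None:
        waits.append(seconds)

    body = await post_with_backoff(fake_send, jitter=0.0, sleep=record_sleep)
    print(body, waits)  # {'output': {}} [1.0, 3.0]


asyncio.run(demo())
```

Injecting `sleep` keeps the demo instant and makes it possible to check the exact wait schedule. With the defaults (`jitter=2.0` and `asyncio.sleep`), the function waits in real time, as `_you_post` does.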

## Ground one ticket

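The `ground_answer` task builds a research query from the ticket question. If the ticket has context, it appends that context under a `Context:` label. The task reads the grounded answer from `output.content` in the Research API response, serializing any non-string content to JSON. It then converts each entry in `output.sources` into a `Source`, keeping only the first snippet and deriving the domain and favicon URL from the source URL. The task is declared with `retries=3`, so Flyte retries any failed call. This includes a 429 that persists through `_you_post`'s own backoff.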

```
@env.task(retries=3)
async def ground_answer(ticket: str, context: str, research_effort: str) -> Grounding:
    """Ground the ticket in fresh public sources via the Research API."""
    question = ticket if not context else f"{ticket}\n\nContext: {context}"
    result = await you_research(question, research_effort)

    output = result.get("output", {})
    answer = output.get("content", "")
    if not isinstance(answer, str):
        answer = json.dumps(answer)

    sources = []
    for s in output.get("sources", []) or []:
        url = str(s.get("url", ""))
        sources.append(
            Source(
                title=str(s.get("title", "") or url),
                url=url,
                snippet=str((s.get("snippets") or [""])[0]),
                domain=_domain(url),
                favicon=_favicon_for(url),
            )
        )
    return Grounding(answer=answer, sources=sources)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/support_resolution_agent/main.py*
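
The parsing step can be tried offline with the sketch below, which copies `ground_answer`'s field access into plain functions. It assumes only the response shape the tutorial's code reads: `output.content`, plus `output.sources` entries with `url`, `title`, and `snippets`. The sample payload is invented, favicons are omitted, and `build_query` and `parse_grounding` are hypothetical names. Unlike the task, the sketch uses `result.get("output") or {}`, so an explicit `null` output does not raise.

```python
import json
from dataclasses import dataclass, field
from urllib.parse import urlparse


@dataclass
class Source:
    title: str
    url: str
    snippet: str
    domain: str = ""


@dataclass
class Grounding:
    answer: str
    sources: list[Source] = field(default_factory=list)


def build_query(question: str, context: str) -> str:
    """Append the ticket context under a 'Context:' label, as ground_answer does."""
    return question if not context else f"{question}\n\nContext: {context}"


def parse_grounding(result: dict) -> Grounding:
    """Read output.content and output.sources the same way ground_answer does."""
    output = result.get("output") or {}  # also tolerates an explicit null "output"
    answer = output.get("content", "")
    if not isinstance(answer, str):
        answer = json.dumps(answer)  # structured content is kept as a JSON string
    sources = []
    for s in output.get("sources") or []:
        url = str(s.get("url", ""))
        sources.append(
            Source(
                title=str(s.get("title", "") or url),  # untitled sources fall back to their URL
                url=url,
                snippet=str((s.get("snippets") or [""])[0]),  # only the first snippet is kept
                domain=urlparse(url).netloc.replace("www.", ""),
            )
        )
    return Grounding(answer=answer, sources=sources)


# Made-up payload containing only the fields ground_answer reads.
sample = {
    "output": {
        "content": "Example answer text.",
        "sources": [
            {"url": "https://www.example.com/recalls", "title": "Recalls", "snippets": ["First.", "Second."]},
            {"url": "https://example.org/faq", "snippets": []},
        ],
    }
}
g = parse_grounding(sample)
print(g.answer, [s.domain for s in g.sources])  # Example answer text. ['example.com', 'example.org']
```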

## Draft a customer-ready reply

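The `draft_reply` task formats the grounding sources as a bulleted list, with each entry giving the title, domain, URL, and quoted snippet. It passes this list and the grounded answer to Claude Haiku (`MODEL`, called through LiteLLM's `acompletion` at temperature 0.2). The system prompt does three things:

- It restricts the model to the provided answer and sources.
- It asks for a concise, friendly reply that cites the relevant source URL inline after each factual claim, so a human agent can verify the reply before sending it.
- It tells the model to say plainly when the sources do not answer the question.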
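
For reference, here is a self-contained sketch of how `draft_reply` and `_draft` assemble the prompt: one bullet per source, then a system/user message pair. The `Source` dataclass is trimmed to the fields used, and the helper names `format_sources` and `build_messages` are hypothetical. The message list uses the role/content chat format that LiteLLM's `acompletion` accepts. Nothing here calls a model.

```python
from dataclasses import dataclass


@dataclass
class Source:
    title: str
    url: str
    snippet: str
    domain: str = ""


def format_sources(sources: list[Source]) -> str:
    """One bullet per source (title, domain, URL), with the quoted snippet on an indented line."""
    return "\n".join(f'- {s.title} ({s.domain}): {s.url}\n  "{s.snippet}"' for s in sources)


def build_messages(system: str, ticket: str, answer: str, sources: list[Source]) -> list[dict[str, str]]:
    """System + user messages laid out the same way as in _draft."""
    user = (
        f"Customer ticket: {ticket}\n\n"
        f"Grounded answer:\n{answer}\n\nSources:\n{format_sources(sources)}"
    )
    return [
        {"role": "system", "content": system},
        {"role": "user", "content": user},
    ]


msgs = build_messages(
    "Cite the source URL after each factual claim.",  # abbreviated system prompt
    "What is the warranty period?",
    "Example grounded answer.",
    [Source("Warranty FAQ", "https://example.com/warranty", "Example snippet.", "example.com")],
)
print(msgs[1]["content"])
```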

```

# {{docs-fragment draft_reply}}
@flyte.trace
async def _draft(ticket: str, answer: str, sources_text: str) -> str:
    from litellm import acompletion

    system = (
        "You are a senior customer-support agent. Using ONLY the grounded "
        "answer and sources provided, draft a concise, friendly, customer-ready "
        "reply. Cite the relevant source URL inline in parentheses after any "
        "factual claim so a human agent can verify before sending. If the "
        "sources do not answer the question, say so plainly."
    )
    user = (
        f"Customer ticket: {ticket}\n\n"
        f"Grounded answer:\n{answer}\n\nSources:\n{sources_text}"
    )
    resp = await acompletion(
        model=MODEL,
        messages=[
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ],
        temperature=0.2,
        max_tokens=1024,
    )
    return resp.choices[0].message.content

@env.task
async def draft_reply(ticket: Ticket, grounding: Grounding) -> Resolution:
    """Turn the grounded answer into a cited, customer-ready reply."""
    sources_text = "\n".join(
        f"- {s.title} ({s.domain}): {s.url}\n  \"{s.snippet}\""
        for s in grounding.sources
    )
    reply = await _draft(ticket.question, grounding.answer, sources_text)

    return Resolution(
        ticket_id=ticket.ticket_id,
        ticket=ticket.question,
        grounded_answer=grounding.answer,
        draft_reply=reply,
        sources=grounding.sources,
    )
# {{/docs-fragment draft_reply}}

# {{docs-fragment resolve_ticket}}
async def resolve_ticket(ticket: Ticket, research_effort: str) -> Resolution:
    """Ground one ticket then draft its reply."""
    grounding = await ground_answer(ticket.question, ticket.context, research_effort)
    return await draft_reply(ticket, grounding)
# {{/docs-fragment resolve_ticket}}

# {{docs-fragment report}}
REPORT_CSS = """
<style>
  .rpt { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto,
         Helvetica, Arial, sans-serif; color:#1f2933; max-width:1040px;
         margin:0 auto; }
  .rpt h1 { font-size:22px; margin:0 0 4px; color:#102a43; }
  .rpt .sub { color:#647488; font-size:13px; margin:0 0 18px; }
  .rpt .stats { display:flex; gap:10px; flex-wrap:wrap; margin:0 0 22px; }
  .rpt .pill { background:#f0f4f8; border-radius:999px; padding:6px 14px;
               font-size:13px; color:#334e68; }
  .rpt .pill b { color:#102a43; }
  .rpt .card { border:1px solid #e4e7eb; border-radius:12px; padding:18px 20px;
               margin:0 0 16px; box-shadow:0 1px 3px rgba(16,42,67,0.06);
               background:#fff; }
  .rpt .tid { display:inline-block; font-size:11px; font-weight:700;
              padding:3px 9px; border-radius:6px; background:#e0e8f9;
              color:#2b4ba0; margin-right:8px; }
  .rpt .q { font-size:15px; font-weight:600; color:#102a43; margin:8px 0 12px; }
  .rpt .reply { background:#f7faf7; border:1px solid #e1ece1; border-radius:8px;
                padding:12px 14px; font-size:14px; line-height:1.55; }
  .rpt .reply h3 { font-size:11px; text-transform:uppercase; letter-spacing:.04em;
                   color:#3c8a5e; margin:0 0 8px; }
  .rpt .sources { margin-top:12px; }
  .rpt .sources h3 { font-size:11px; text-transform:uppercase; color:#627d98;
                     margin:0 0 8px; }
  .rpt a { color:#2b6cb0; text-decoration:none; }
  .rpt a:hover { text-decoration:underline; }
  .rpt .empty { color:#829ab1; font-style:italic; padding:8px 0; }
  .rpt .cite { display:flex; gap:9px; align-items:flex-start; background:#f7f9fb;
               border:1px solid #eef1f4; border-radius:8px; padding:7px 10px;
               margin:0 0 6px; }
  .rpt .cite img.fav { width:15px; height:15px; border-radius:3px; margin-top:2px;
                       flex:0 0 auto; background:#e4e7eb; }
  .rpt .cite .cb { font-size:12px; line-height:1.4; }
  .rpt .cite .cdom { font-weight:600; color:#334e68; }
  .rpt .cite .ctag { font-size:10px; font-weight:700; text-transform:uppercase;
                     color:#fff; background:#5b8def; border-radius:4px;
                     padding:1px 5px; margin-left:6px; }
  .rpt .cite .cmeta { color:#829ab1; }
  .rpt .cite .csnip { color:#52606d; font-style:italic; margin-top:2px; }
  .rpt .yoube { font-size:11px; color:#9aa5b1; margin-top:4px; }
</style>
"""

def _cite(s: Source) -> str:
    """Render a rich You.com Research citation for a support source."""
    if not s.url:
        return ""
    snip = f"<div class='csnip'>&ldquo;{s.snippet}&rdquo;</div>" if s.snippet else ""
    return (
        f"<div class='cite'><img class='fav' src='{s.favicon}' alt=''/>"
        f"<div class='cb'>"
        f"<a href='{s.url}'><span class='cdom'>{s.domain or 'source'}</span></a>"
        f"<span class='ctag'>research</span>"
        f"<div class='cmeta'>{s.title}</div>{snip}</div></div>"
    )

def _render_report(report: ResolutionReport) -> str:
    cards = []
    for res in report.resolutions:
        src = "".join(_cite(s) for s in res.sources[:8])
        reply_html = res.draft_reply.replace("\n", "<br/>")
        cards.append(
            f"<div class='card'>"
            f"<div><span class='tid'>{res.ticket_id}</span></div>"
            f"<div class='q'>{res.ticket}</div>"
            f"<div class='reply'><h3>Draft reply (for human review)</h3>{reply_html}</div>"
            + (f"<div class='sources'><h3>You.com sources ({len(res.sources)})</h3>{src}</div>" if src else "")
            + "</div>"
        )

    total_sources = sum(len(r.sources) for r in report.resolutions)
    return f"""
    {REPORT_CSS}
    <div class="rpt">
      <h1>Support Resolutions</h1>
      <p class="sub">Tickets grounded in fresh public sources via the You.com
      Research API — draft replies cite sources a human agent can verify.</p>
      <div class="stats">
        <span class="pill"><b>{len(report.resolutions)}</b> tickets</span>
        <span class="pill"><b>{total_sources}</b> You.com sources cited</span>
      </div>
      {''.join(cards) or "<p class='empty'>No tickets processed.</p>"}
      <p class="yoube">Each ticket grounded by the You.com Research API
      (<code>lite</code> effort for low-latency, human-in-the-loop use). Sources
      include domain, title, and snippet provenance — ready to paste into a
      customer reply with verification links.</p>
    </div>
    """
# {{/docs-fragment report}}

# {{docs-fragment driver}}
def _default_tickets() -> list[Ticket]:
    return [
        Ticket(
            "tkt-1",
            "Is there a recall on the DeWalt DCD777 cordless drill, and what should "
            "the customer do if there is?",
            "Customer purchased the drill recently and is asking about safety recalls.",
        ),
        Ticket(
            "tkt-2",
            "What is Sony's current return policy for the WH-1000XM5 headphones?",
            "Customer wants to return an opened pair bought 20 days ago.",
        ),
        Ticket(
            "tkt-3",
            "Are there any current weather advisories that could delay flights out of "
            "Denver International Airport today?",
            "Customer is worried about a connecting flight.",
        ),
        Ticket(
            "tkt-4",
            "What are the dimensions and weight capacity of the IKEA BEKANT desk?",
            "Customer is checking if it fits their space before resolving a complaint.",
        ),
        Ticket(
            "tkt-5",
            "Has Samsung issued any recall or safety notice for the Galaxy Z Fold5?",
            "Customer reports overheating and wants to know about known issues.",
        ),
        Ticket(
            "tkt-6",
            "What is the warranty period for a Dyson V15 Detect vacuum in the US?",
            "Customer's vacuum stopped working and asks about coverage.",
        ),
    ]

@env.task(report=True)
async def support_resolution(
    tickets: list[Ticket] | None = None,
    research_effort: str = "lite",
) -> ResolutionReport:
    """Fan out across support tickets, grounding and drafting cited replies."""
    if tickets is None:
        tickets = _default_tickets()

    with flyte.group("resolve-tickets"):
        resolutions = await asyncio.gather(
            *[resolve_ticket(t, research_effort) for t in tickets]
        )

    report = ResolutionReport(resolutions=list(resolutions))
    await flyte.report.replace.aio(_render_report(report), do_flush=True)
    await flyte.report.flush.aio()
    return report
# {{/docs-fragment driver}}

# {{docs-fragment main}}
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(support_resolution)
    print(run.url)
    run.wait()
# {{/docs-fragment main}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/support_resolution_agent/main.py*

## Resolve one ticket

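`resolve_ticket` handles a single ticket. It awaits `ground_answer` and then passes the resulting `Grounding` to `draft_reply`, so the two steps run in sequence for each ticket. `resolve_ticket` itself is a plain `async` function rather than an `env.task`; only `ground_answer` and `draft_reply` are tasks.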

```
async def resolve_ticket(ticket: Ticket, research_effort: str) -> Resolution:
    """Ground one ticket then draft its reply."""
    grounding = await ground_answer(ticket.question, ticket.context, research_effort)
    return await draft_reply(ticket, grounding)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/support_resolution_agent/main.py*
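Within one ticket, drafting has to wait for grounding. Separate tickets, however, do not depend on each other. Below is a minimal plain-`asyncio` sketch of that shape. It does not use Flyte: `ground` and `draft` are hypothetical stand-ins for the `ground_answer` and `draft_reply` tasks, and a shared `log` list records the order in which steps start.

```python
import asyncio
from dataclasses import dataclass


@dataclass
class Ticket:
    ticket_id: str
    question: str


async def ground(ticket: Ticket, log: list[str]) -> str:
    """Stand-in for ground_answer: pretend to call a research API."""
    log.append(f"ground:{ticket.ticket_id}")
    await asyncio.sleep(0.01)
    return f"answer for {ticket.ticket_id}"


async def draft(ticket: Ticket, answer: str, log: list[str]) -> str:
    """Stand-in for draft_reply: needs the grounded answer as input."""
    log.append(f"draft:{ticket.ticket_id}")
    await asyncio.sleep(0.01)
    return f"reply to {ticket.ticket_id} using {answer!r}"


async def resolve(ticket: Ticket, log: list[str]) -> str:
    # Sequential within one ticket: drafting depends on grounding.
    answer = await ground(ticket, log)
    return await draft(ticket, answer, log)


async def resolve_all(tickets: list[Ticket]) -> tuple[list[str], list[str]]:
    log: list[str] = []
    # Concurrent across tickets; gather returns results in input order.
    replies = await asyncio.gather(*(resolve(t, log) for t in tickets))
    return list(replies), log
```

For example, `asyncio.run(resolve_all([Ticket("tkt-1", "q1"), Ticket("tkt-2", "q2")]))` starts grounding for both tickets before either one starts drafting. The replies still come back in the same order as the input tickets.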

## Orchestration

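The `support_resolution` driver task falls back to `_default_tickets()` when no tickets are passed in. It fans out across all tickets with `asyncio.gather` inside a `flyte.group("resolve-tickets")`, then renders a Flyte report (enabled by `report=True` on the task) that shows every draft reply and its sources.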

```
# {{docs-fragment report}}
REPORT_CSS = """
<style>
  .rpt { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto,
         Helvetica, Arial, sans-serif; color:#1f2933; max-width:1040px;
         margin:0 auto; }
  .rpt h1 { font-size:22px; margin:0 0 4px; color:#102a43; }
  .rpt .sub { color:#647488; font-size:13px; margin:0 0 18px; }
  .rpt .stats { display:flex; gap:10px; flex-wrap:wrap; margin:0 0 22px; }
  .rpt .pill { background:#f0f4f8; border-radius:999px; padding:6px 14px;
               font-size:13px; color:#334e68; }
  .rpt .pill b { color:#102a43; }
  .rpt .card { border:1px solid #e4e7eb; border-radius:12px; padding:18px 20px;
               margin:0 0 16px; box-shadow:0 1px 3px rgba(16,42,67,0.06);
               background:#fff; }
  .rpt .tid { display:inline-block; font-size:11px; font-weight:700;
              padding:3px 9px; border-radius:6px; background:#e0e8f9;
              color:#2b4ba0; margin-right:8px; }
  .rpt .q { font-size:15px; font-weight:600; color:#102a43; margin:8px 0 12px; }
  .rpt .reply { background:#f7faf7; border:1px solid #e1ece1; border-radius:8px;
                padding:12px 14px; font-size:14px; line-height:1.55; }
  .rpt .reply h3 { font-size:11px; text-transform:uppercase; letter-spacing:.04em;
                   color:#3c8a5e; margin:0 0 8px; }
  .rpt .sources { margin-top:12px; }
  .rpt .sources h3 { font-size:11px; text-transform:uppercase; color:#627d98;
                     margin:0 0 8px; }
  .rpt a { color:#2b6cb0; text-decoration:none; }
  .rpt a:hover { text-decoration:underline; }
  .rpt .empty { color:#829ab1; font-style:italic; padding:8px 0; }
  .rpt .cite { display:flex; gap:9px; align-items:flex-start; background:#f7f9fb;
               border:1px solid #eef1f4; border-radius:8px; padding:7px 10px;
               margin:0 0 6px; }
  .rpt .cite img.fav { width:15px; height:15px; border-radius:3px; margin-top:2px;
                       flex:0 0 auto; background:#e4e7eb; }
  .rpt .cite .cb { font-size:12px; line-height:1.4; }
  .rpt .cite .cdom { font-weight:600; color:#334e68; }
  .rpt .cite .ctag { font-size:10px; font-weight:700; text-transform:uppercase;
                     color:#fff; background:#5b8def; border-radius:4px;
                     padding:1px 5px; margin-left:6px; }
  .rpt .cite .cmeta { color:#829ab1; }
  .rpt .cite .csnip { color:#52606d; font-style:italic; margin-top:2px; }
  .rpt .yoube { font-size:11px; color:#9aa5b1; margin-top:4px; }
</style>
"""

def _cite(s: Source) -> str:
    """Render a rich You.com Research citation for a support source."""
    if not s.url:
        return ""
    snip = f"<div class='csnip'>&ldquo;{s.snippet}&rdquo;</div>" if s.snippet else ""
    return (
        f"<div class='cite'><img class='fav' src='{s.favicon}' alt=''/>"
        f"<div class='cb'>"
        f"<a href='{s.url}'><span class='cdom'>{s.domain or 'source'}</span></a>"
        f"<span class='ctag'>research</span>"
        f"<div class='cmeta'>{s.title}</div>{snip}</div></div>"
    )

def _render_report(report: ResolutionReport) -> str:
    cards = []
    for res in report.resolutions:
        src = "".join(_cite(s) for s in res.sources[:8])
        reply_html = res.draft_reply.replace("\n", "<br/>")
        cards.append(
            f"<div class='card'>"
            f"<div><span class='tid'>{res.ticket_id}</span></div>"
            f"<div class='q'>{res.ticket}</div>"
            f"<div class='reply'><h3>Draft reply (for human review)</h3>{reply_html}</div>"
            + (f"<div class='sources'><h3>You.com sources ({len(res.sources)})</h3>{src}</div>" if src else "")
            + "</div>"
        )

    total_sources = sum(len(r.sources) for r in report.resolutions)
    return f"""
    {REPORT_CSS}
    <div class="rpt">
      <h1>Support Resolutions</h1>
      <p class="sub">Tickets grounded in fresh public sources via the You.com
      Research API — draft replies cite sources a human agent can verify.</p>
      <div class="stats">
        <span class="pill"><b>{len(report.resolutions)}</b> tickets</span>
        <span class="pill"><b>{total_sources}</b> You.com sources cited</span>
      </div>
      {''.join(cards) or "<p class='empty'>No tickets processed.</p>"}
      <p class="yoube">Each ticket grounded by the You.com Research API
      (<code>lite</code> effort for low-latency, human-in-the-loop use). Sources
      include domain, title, and snippet provenance — ready to paste into a
      customer reply with verification links.</p>
    </div>
    """
# {{/docs-fragment report}}

# {{docs-fragment driver}}
def _default_tickets() -> list[Ticket]:
    return [
        Ticket(
            "tkt-1",
            "Is there a recall on the DeWalt DCD777 cordless drill, and what should "
            "the customer do if there is?",
            "Customer purchased the drill recently and is asking about safety recalls.",
        ),
        Ticket(
            "tkt-2",
            "What is Sony's current return policy for the WH-1000XM5 headphones?",
            "Customer wants to return an opened pair bought 20 days ago.",
        ),
        Ticket(
            "tkt-3",
            "Are there any current weather advisories that could delay flights out of "
            "Denver International Airport today?",
            "Customer is worried about a connecting flight.",
        ),
        Ticket(
            "tkt-4",
            "What are the dimensions and weight capacity of the IKEA BEKANT desk?",
            "Customer is checking if it fits their space before resolving a complaint.",
        ),
        Ticket(
            "tkt-5",
            "Has Samsung issued any recall or safety notice for the Galaxy Z Fold5?",
            "Customer reports overheating and wants to know about known issues.",
        ),
        Ticket(
            "tkt-6",
            "What is the warranty period for a Dyson V15 Detect vacuum in the US?",
            "Customer's vacuum stopped working and asks about coverage.",
        ),
    ]

@env.task(report=True)
async def support_resolution(
    tickets: list[Ticket] | None = None,
    research_effort: str = "lite",
) -> ResolutionReport:
    """Fan out across support tickets, grounding and drafting cited replies."""
    if tickets is None:
        tickets = _default_tickets()

    with flyte.group("resolve-tickets"):
        resolutions = await asyncio.gather(
            *[resolve_ticket(t, research_effort) for t in tickets]
        )

    report = ResolutionReport(resolutions=list(resolutions))
    await flyte.report.replace.aio(_render_report(report), do_flush=True)
    await flyte.report.flush.aio()
    return report
# {{/docs-fragment driver}}

# {{docs-fragment main}}
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(support_resolution)
    print(run.url)
    run.wait()
# {{/docs-fragment main}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/support_resolution_agent/main.py*

## Run the agent

### Create secrets

Get a You.com API key from the [You.com platform](https://you.com/platform) (see the [quickstart guide](https://you.com/docs/quickstart)). Get an Anthropic API key from the [Anthropic console](https://console.anthropic.com/).

Register both keys as Flyte secrets. The secret key names must match those declared in the `TaskEnvironment`:

```
flyte create secret youdotcom-api-key <YOUR_YOU_API_KEY>
flyte create secret internal-anthropic-api-key <YOUR_ANTHROPIC_API_KEY>
```

See [Secrets](https://www.union.ai/docs/v2/union/user-guide/task-configuration/secrets/page.md) for scoping and file-based secrets.

### Run locally or remotely

From the [example directory](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/support_resolution_agent):

```
cd v2/tutorials/support_resolution_agent
uv run --script main.py
```

To test locally without Flyte secrets, export the keys as environment variables instead:

```
export YOU_API_KEY=<YOUR_YOU_API_KEY>
export ANTHROPIC_API_KEY=<YOUR_ANTHROPIC_API_KEY>

uv run --script main.py
```

When the run completes, open the Flyte report to review draft replies for each ticket, with You.com source citations ready for a human agent to verify and paste into a customer response.

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/context-engineering ===

# Context engineering

Tutorials for prompt engineering, prompt optimization, and context construction.

### **Context engineering > Automatic prompt engineering**

Easily run prompt optimization with real-time observability, traceability, and automatic recovery.

### **Context engineering > Text-to-SQL prompt optimization**

Learn how to turn natural language questions into SQL queries with Flyte and LlamaIndex, and explore prompt optimization in practice.

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/context-engineering/text_to_sql ===

# Text-to-SQL prompt optimization

> [!NOTE]
> Code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/text_to_sql); based on work by [LlamaIndex](https://docs.llamaindex.ai/en/stable/examples/workflow/advanced_text_to_sql/).

Data analytics drives modern decision-making, but SQL often creates a bottleneck. Writing queries requires technical expertise, so non-technical stakeholders must rely on data teams. That translation layer slows everyone down.

Text-to-SQL narrows this gap by turning natural language into executable SQL queries. It lowers the barrier to structured data and makes databases accessible to more people.

In this tutorial, we build a Text-to-SQL workflow using LlamaIndex and evaluate it on the [WikiTableQuestions dataset](https://ppasupat.github.io/WikiTableQuestions/) (a benchmark of natural language questions over semi-structured tables). We then explore prompt optimization to see whether it improves accuracy and show how to track prompts and results over time. Along the way, we'll see what worked, what didn't, and what we learned about building durable evaluation pipelines. The pattern here can be adapted to your own datasets and workflows.

![Evaluation](https://raw.githubusercontent.com/unionai/unionai-docs-static/main/images/tutorials/text-to-sql/evaluation.png)

## Ingesting data

We start by ingesting the WikiTableQuestions dataset, which comes as CSV files, into a SQLite database. This database serves as the source of truth for our Text-to-SQL pipeline.

```
import asyncio
import fnmatch
import os
import re
import zipfile

import flyte
import pandas as pd
import requests
from flyte.io import Dir, File
from llama_index.core.llms import ChatMessage
from llama_index.core.prompts import ChatPromptTemplate
from llama_index.llms.openai import OpenAI
from pydantic import BaseModel, Field
from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine
from utils import env

# {{docs-fragment table_info}}
class TableInfo(BaseModel):
    """Information regarding a structured table."""

    table_name: str = Field(..., description="table name (underscores only, no spaces)")
    table_summary: str = Field(
        ..., description="short, concise summary/caption of the table"
    )

# {{/docs-fragment table_info}}

@env.task
async def download_and_extract(zip_path: str, search_glob: str) -> Dir:
    """Download and extract the dataset zip file if not already available."""
    output_zip = "data.zip"
    extract_dir = "wiki_table_questions"

    if not os.path.exists(zip_path):
        response = requests.get(zip_path, stream=True, timeout=60)
        response.raise_for_status()  # fail fast instead of writing an error page to data.zip
        with open(output_zip, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
    else:
        output_zip = zip_path
        print(f"Using existing file {output_zip}")

    os.makedirs(extract_dir, exist_ok=True)
    with zipfile.ZipFile(output_zip, "r") as zip_ref:
        for member in zip_ref.namelist():
            if fnmatch.fnmatch(member, search_glob):
                zip_ref.extract(member, extract_dir)

    remote_dir = await Dir.from_local(extract_dir)
    return remote_dir

async def read_csv_file(
    csv_file: File, nrows: int | None = None
) -> pd.DataFrame | None:
    """Safely download and parse a CSV file into a DataFrame."""
    try:
        local_csv_file = await csv_file.download()
        return pd.read_csv(local_csv_file, nrows=nrows)
    except Exception as e:
        print(f"Error parsing {csv_file.path}: {e}")
        return None

def sanitize_column_name(col_name: str) -> str:
    """Sanitize column names by replacing spaces/special chars with underscores."""
    return re.sub(r"\W+", "_", col_name)

async def create_table_from_dataframe(
    df: pd.DataFrame, table_name: str, engine, metadata_obj
):
    """Create a SQL table from a Pandas DataFrame."""
    # Sanitize column names
    sanitized_columns = {col: sanitize_column_name(col) for col in df.columns}
    df = df.rename(columns=sanitized_columns)

    # Define table columns based on DataFrame dtypes
    columns = [
        Column(col, String if dtype == "object" else Integer)
        for col, dtype in zip(df.columns, df.dtypes)
    ]

    table = Table(table_name, metadata_obj, *columns)

    # Create table in database
    metadata_obj.create_all(engine)

    # Insert data into table
    with engine.begin() as conn:
        for _, row in df.iterrows():
            conn.execute(table.insert().values(**row.to_dict()))

@flyte.trace
async def create_table(
    csv_file: File, table_info: TableInfo, database_path: str
) -> str:
    """Safely create a table from CSV if parsing succeeds."""
    df = await read_csv_file(csv_file)
    if df is None:
        return "false"

    print(f"Creating table: {table_info.table_name}")

    engine = create_engine(f"sqlite:///{database_path}")
    metadata_obj = MetaData()

    await create_table_from_dataframe(df, table_info.table_name, engine, metadata_obj)
    return "true"

@flyte.trace
async def llm_structured_predict(
    df_str: str,
    table_names: list[str],
    prompt_tmpl: ChatPromptTemplate,
    feedback: str,
    llm: OpenAI,
) -> TableInfo:
    # Use the async variant so concurrent CSV files don't block the event loop.
    return await llm.astructured_predict(
        TableInfo,
        prompt_tmpl,
        feedback=feedback,
        table_str=df_str,
        exclude_table_name_list=str(list(table_names)),
    )

async def generate_unique_table_info(
    df_str: str,
    table_names: list[str],
    prompt_tmpl: ChatPromptTemplate,
    llm: OpenAI,
    tablename_lock: asyncio.Lock,
    retries: int = 3,
) -> TableInfo | None:
    """Ask the LLM for a table name, retrying with feedback until the name is unused."""
    last_table_name = None
    for attempt in range(retries):
        feedback = ""
        if attempt > 0:
            feedback = f"Note: '{last_table_name}' already exists. Please pick a new name not in {table_names}."

        table_info = await llm_structured_predict(
            df_str, table_names, prompt_tmpl, feedback, llm
        )
        last_table_name = table_info.table_name

        async with tablename_lock:
            if table_info.table_name not in table_names:
                table_names.append(table_info.table_name)
                return table_info

        print(f"Table name {table_info.table_name} already exists, retrying...")

    return None

async def process_csv_file(
    csv_file: File,
    table_names: list[str],
    semaphore: asyncio.Semaphore,
    tablename_lock: asyncio.Lock,
    llm: OpenAI,
    prompt_tmpl: ChatPromptTemplate,
) -> TableInfo | None:
    """Process a single CSV file to generate a unique TableInfo."""
    async with semaphore:
        df = await read_csv_file(csv_file, nrows=10)
        if df is None:
            return None
        return await generate_unique_table_info(
            df.to_csv(), table_names, prompt_tmpl, llm, tablename_lock
        )

@env.task
async def extract_table_info(
    data_dir: Dir, model: str, concurrency: int
) -> list[TableInfo | None]:
    """Extract structured table information from CSV files."""
    table_names: list[str] = []
    semaphore = asyncio.Semaphore(concurrency)
    tablename_lock = asyncio.Lock()
    llm = OpenAI(model=model)

    prompt_str = """\
    Provide a JSON object with the following fields:

    - `table_name`: must be unique and descriptive (underscores only, no generic names).
    - `table_summary`: short and concise summary of the table.

    Do NOT use any of these table names: {exclude_table_name_list}

    Table:
    {table_str}

    {feedback}
    """
    prompt_tmpl = ChatPromptTemplate(
        message_templates=[ChatMessage.from_str(prompt_str, role="user")]
    )

    tasks = [
        process_csv_file(
            csv_file, table_names, semaphore, tablename_lock, llm, prompt_tmpl
        )
        async for csv_file in data_dir.walk()
    ]

    return await asyncio.gather(*tasks)

# {{docs-fragment data_ingestion}}
@env.task
async def data_ingestion(
    csv_zip_path: str = "https://github.com/ppasupat/WikiTableQuestions/releases/download/v1.0.2/WikiTableQuestions-1.0.2-compact.zip",
    search_glob: str = "WikiTableQuestions/csv/200-csv/*.csv",
    concurrency: int = 5,
    model: str = "gpt-4o-mini",
) -> tuple[File, list[TableInfo | None]]:
    """Main data ingestion pipeline: download → extract → analyze → create DB."""
    data_dir = await download_and_extract(csv_zip_path, search_glob)
    table_infos = await extract_table_info(data_dir, model, concurrency)

    database_path = "wiki_table_questions.db"

    i = 0
    async for csv_file in data_dir.walk():
        table_info = table_infos[i]
        if table_info:
            ok = await create_table(csv_file, table_info, database_path)
            if ok == "false":
                table_infos[i] = None
        else:
            print(f"Skipping table creation for {csv_file} due to missing TableInfo.")
        i += 1

    db_file = await File.from_local(database_path)
    return db_file, table_infos

# {{/docs-fragment data_ingestion}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/data_ingestion.py*

The ingestion step:

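1. Downloads the dataset zip archive from GitHub, unless `csv_zip_path` already points to a local file.
2. Extracts the CSV files that match `search_glob` (by default, the `200-csv` subset).
3. Generates a unique table name and a short summary for each CSV by prompting an LLM (`gpt-4o-mini` by default) with the file's first 10 rows, processing up to `concurrency` files at a time.
4. Creates a corresponding SQLite table for each CSV that received metadata.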
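
The table-creation step can be sketched with only the standard library. This is an illustration under assumptions, not the tutorial's code: it swaps pandas and SQLAlchemy for `csv` and `sqlite3`, stores every column as `TEXT` (the tutorial maps non-object dtypes to `Integer`), reuses the same `\W+` rule as `sanitize_column_name`, and, unlike the tutorial, de-duplicates column names that collide after sanitizing and inserts all rows with one `executemany` call. `unique_columns` and `load_csv_into_sqlite` are hypothetical names.

```python
import csv
import io
import re
import sqlite3


def sanitize_column_name(col_name: str) -> str:
    # Same rule as the tutorial: each run of non-word characters becomes "_".
    return re.sub(r"\W+", "_", col_name)


def unique_columns(headers: list[str]) -> list[str]:
    """Sanitize headers, then suffix collisions ("a b", "a-b" -> "a_b", "a_b_1")."""
    used: set[str] = set()
    out: list[str] = []
    for header in headers:
        base = sanitize_column_name(header) or "col"
        name, n = base, 1
        while name in used:
            name = f"{base}_{n}"
            n += 1
        used.add(name)
        out.append(name)
    return out


def load_csv_into_sqlite(conn: sqlite3.Connection, table_name: str, csv_text: str) -> int:
    """Create `table_name` from CSV text; return the number of rows inserted.

    Assumes `table_name` is already sanitized (no double quotes).
    """
    reader = csv.reader(io.StringIO(csv_text))
    columns = unique_columns(next(reader))
    # Identifiers can't be bound as parameters, so they are double-quoted;
    # row values go through "?" placeholders instead.
    col_defs = ", ".join(f'"{c}" TEXT' for c in columns)
    placeholders = ", ".join("?" for _ in columns)
    conn.execute(f'CREATE TABLE "{table_name}" ({col_defs})')
    # Rows whose field count doesn't match the header are skipped.
    rows = [row for row in reader if len(row) == len(columns)]
    conn.executemany(f'INSERT INTO "{table_name}" VALUES ({placeholders})', rows)
    conn.commit()
    return len(rows)
```

Against an in-memory database (`sqlite3.connect(":memory:")`), a header such as `Year,Album Title,Album-Title` produces the columns `Year`, `Album_Title`, and `Album_Title_1`.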

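The `data_ingestion` task returns both the SQLite database (as a Flyte `File`) and the generated table metadata. A metadata entry is `None` when its CSV could not be parsed or given a unique name.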

Once Union artifacts are available (coming soon), you'll be able to persist the ingested SQLite database as an artifact, so pipelines won't need to rerun data ingestion every time.

## From question to SQL

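Next, we define a workflow that converts natural language into executable SQL using a retrieval-augmented generation (RAG) approach: it retrieves the tables most relevant to the question, generates SQL from their schemas and example rows, runs the query, and synthesizes a natural-language answer.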
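
The retrieval half of that approach keeps the prompt small. Rather than pasting every schema into the prompt, `retrieve_tables` (in the listing that follows) indexes each table together with its LLM-generated summary, keeps the 3 best matches for the question (`similarity_top_k=3`), and adds 2 similar example rows per table. The sketch below imitates only the ranking step. It is an assumption-laden stand-in: a bag-of-words cosine similarity replaces the embedding model LlamaIndex uses, and `top_k_tables` and `build_context` are hypothetical helpers.

```python
import math
import re
from collections import Counter


def bag_of_words(text: str) -> Counter[str]:
    # Lowercased word counts; a crude stand-in for an embedding vector.
    return Counter(re.findall(r"[a-z0-9]+", text.lower()))


def cosine(a: Counter[str], b: Counter[str]) -> float:
    dot = sum(a[w] * b[w] for w in a.keys() & b.keys())
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


def top_k_tables(question: str, summaries: dict[str, str], k: int = 3) -> list[str]:
    """Rank tables by similarity between the question and `name + summary`."""
    q = bag_of_words(question)
    scored = [
        (cosine(q, bag_of_words(f"{name.replace('_', ' ')} {summary}")), name)
        for name, summary in summaries.items()
    ]
    # Highest score first; ties broken by table name for determinism.
    scored.sort(key=lambda pair: (-pair[0], pair[1]))
    return [name for _, name in scored[:k]]


def build_context(tables: list[str], schemas: dict[str, str]) -> str:
    # Only the retrieved tables' schemas go into the prompt's {schema} slot.
    return "\n\n".join(schemas[t] for t in tables)
```

With the tutorial's default question, tables whose summaries mention The Notorious B.I.G. or Bad Boy Records would outrank unrelated ones such as a medal table; real embeddings additionally match synonyms that word overlap misses.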

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "llama-index-core>=0.11.0",
#    "llama-index-llms-openai>=0.2.0",
#    "sqlalchemy>=2.0.0",
#    "pandas>=2.0.0",
#    "requests>=2.25.0",
#    "pydantic>=2.0.0",
# ]
# main = "text_to_sql"
# params = ""
# ///

import asyncio
from pathlib import Path

import flyte
from data_ingestion import TableInfo, data_ingestion
from flyte.io import Dir, File
from llama_index.core import (
    PromptTemplate,
    SQLDatabase,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.core.llms import ChatResponse
from llama_index.core.objects import ObjectIndex, SQLTableNodeMapping, SQLTableSchema
from llama_index.core.prompts.prompt_type import PromptType
from llama_index.core.retrievers import SQLRetriever
from llama_index.core.schema import TextNode
from llama_index.llms.openai import OpenAI
from sqlalchemy import create_engine, text
from utils import env

# {{docs-fragment index_tables}}
@flyte.trace
async def index_table(table_name: str, table_index_dir: str, database_uri: str) -> str:
    """Index a single table into vector store."""
    path = f"{table_index_dir}/{table_name}"
    engine = create_engine(database_uri)

    def _fetch_rows():
        with engine.connect() as conn:
            cursor = conn.execute(text(f'SELECT * FROM "{table_name}"'))
            return cursor.fetchall()

    result = await asyncio.to_thread(_fetch_rows)
    nodes = [TextNode(text=str(tuple(row))) for row in result]
    index = VectorStoreIndex(nodes)
    index.set_index_id("vector_index")
    index.storage_context.persist(path)

    return path

@env.task
async def index_all_tables(db_file: File) -> Dir:
    """Index all tables concurrently."""
    table_index_dir = "table_indices"
    Path(table_index_dir).mkdir(exist_ok=True)

    await db_file.download(local_path="local_db.sqlite")
    engine = create_engine("sqlite:///local_db.sqlite")
    sql_database = SQLDatabase(engine)

    tasks = [
        index_table(t, table_index_dir, "sqlite:///local_db.sqlite")
        for t in sql_database.get_usable_table_names()
    ]
    await asyncio.gather(*tasks)

    remote_dir = await Dir.from_local(table_index_dir)
    return remote_dir

# {{/docs-fragment index_tables}}

@flyte.trace
async def get_table_schema_context(
    table_schema_obj: SQLTableSchema,
    database_uri: str,
) -> str:
    """Retrieve schema + optional description context for a single table."""
    engine = create_engine(database_uri)
    sql_database = SQLDatabase(engine)

    table_info = sql_database.get_single_table_info(table_schema_obj.table_name)

    if table_schema_obj.context_str:
        table_info += f" The table description is: {table_schema_obj.context_str}"

    return table_info

@flyte.trace
async def get_table_row_context(
    table_schema_obj: SQLTableSchema,
    local_vector_index_dir: str,
    query: str,
) -> str:
    """Retrieve row-level context examples using vector search."""
    storage_context = StorageContext.from_defaults(
        persist_dir=str(f"{local_vector_index_dir}/{table_schema_obj.table_name}")
    )
    vector_index = load_index_from_storage(storage_context, index_id="vector_index")
    vector_retriever = vector_index.as_retriever(similarity_top_k=2)
    relevant_nodes = vector_retriever.retrieve(query)

    if not relevant_nodes:
        return ""

    row_context = "\nHere are some relevant example rows (values in the same order as columns above)\n"
    for node in relevant_nodes:
        row_context += str(node.get_content()) + "\n"

    return row_context

async def process_table(
    table_schema_obj: SQLTableSchema,
    database_uri: str,
    local_vector_index_dir: str,
    query: str,
) -> str:
    """Combine schema + row context for one table."""
    table_info = await get_table_schema_context(table_schema_obj, database_uri)
    row_context = await get_table_row_context(
        table_schema_obj, local_vector_index_dir, query
    )

    full_context = table_info
    if row_context:
        full_context += "\n" + row_context

    print(f"Table Info: {full_context}")
    return full_context

async def get_table_context_and_rows_str(
    query: str,
    database_uri: str,
    table_schema_objs: list[SQLTableSchema],
    vector_index_dir: Dir,
):
    """Get combined schema + row context for all tables."""
    local_vector_index_dir = await vector_index_dir.download()

    # run per-table work concurrently
    context_strs = await asyncio.gather(
        *[
            process_table(t, database_uri, local_vector_index_dir, query)
            for t in table_schema_objs
        ]
    )

    return "\n\n".join(context_strs)

# {{docs-fragment retrieve_tables}}
@env.task
async def retrieve_tables(
    query: str,
    table_infos: list[TableInfo | None],
    db_file: File,
    vector_index_dir: Dir,
) -> str:
    """Retrieve relevant tables and return schema context string."""
    await db_file.download(local_path="local_db.sqlite")
    engine = create_engine("sqlite:///local_db.sqlite")
    sql_database = SQLDatabase(engine)

    table_node_mapping = SQLTableNodeMapping(sql_database)
    table_schema_objs = [
        SQLTableSchema(table_name=t.table_name, context_str=t.table_summary)
        for t in table_infos
        if t is not None
    ]

    obj_index = ObjectIndex.from_objects(
        table_schema_objs,
        table_node_mapping,
        VectorStoreIndex,
    )
    obj_retriever = obj_index.as_retriever(similarity_top_k=3)

    retrieved_schemas = obj_retriever.retrieve(query)
    return await get_table_context_and_rows_str(
        query, "sqlite:///local_db.sqlite", retrieved_schemas, vector_index_dir
    )

# {{/docs-fragment retrieve_tables}}

def parse_response_to_sql(chat_response: ChatResponse) -> str:
    """Extract SQL query from LLM response."""
    response = chat_response.message.content
    sql_query_start = response.find("SQLQuery:")
    if sql_query_start != -1:
        response = response[sql_query_start:]
        if response.startswith("SQLQuery:"):
            response = response[len("SQLQuery:") :]
    sql_result_start = response.find("SQLResult:")
    if sql_result_start != -1:
        response = response[:sql_result_start]
    return response.strip().strip("```").strip()

# {{docs-fragment sql_and_response}}
@env.task
async def generate_sql(query: str, table_context: str, model: str, prompt: str) -> str:
    """Generate SQL query from natural language question and table context."""
    llm = OpenAI(model=model)

    fmt_messages = (
        PromptTemplate(
            prompt,
            prompt_type=PromptType.TEXT_TO_SQL,
        )
        .partial_format(dialect="sqlite")
        .format_messages(query_str=query, schema=table_context)
    )

    chat_response = await llm.achat(fmt_messages)
    return parse_response_to_sql(chat_response)

@env.task
async def generate_response(query: str, sql: str, db_file: File, model: str) -> str:
    """Run SQL query on database and synthesize final response."""
    await db_file.download(local_path="local_db.sqlite")

    engine = create_engine("sqlite:///local_db.sqlite")
    sql_database = SQLDatabase(engine)
    sql_retriever = SQLRetriever(sql_database)

    retrieved_rows = sql_retriever.retrieve(sql)

    response_synthesis_prompt = PromptTemplate(
        "Given an input question, synthesize a response from the query results.\n"
        "Query: {query_str}\n"
        "SQL: {sql_query}\n"
        "SQL Response: {context_str}\n"
        "Response: "
    )

    llm = OpenAI(model=model)
    fmt_messages = response_synthesis_prompt.format_messages(
        sql_query=sql,
        context_str=str(retrieved_rows),
        query_str=query,
    )
    chat_response = await llm.achat(fmt_messages)
    return chat_response.message.content

# {{/docs-fragment sql_and_response}}

# {{docs-fragment text_to_sql}}
@env.task
async def text_to_sql(
    system_prompt: str = (
        "Given an input question, first create a syntactically correct {dialect} "
        "query to run, then look at the results of the query and return the answer. "
        "You can order the results by a relevant column to return the most "
        "interesting examples in the database.\n\n"
        "Never query for all the columns from a specific table, only ask for a "
        "few relevant columns given the question.\n\n"
        "Pay attention to use only the column names that you can see in the schema "
        "description. "
        "Be careful to not query for columns that do not exist. "
        "Pay attention to which column is in which table. "
        "Also, qualify column names with the table name when needed. "
        "You are required to use the following format, each taking one line:\n\n"
        "Question: Question here\n"
        "SQLQuery: SQL Query to run\n"
        "SQLResult: Result of the SQLQuery\n"
        "Answer: Final answer here\n\n"
        "Only use tables listed below.\n"
        "{schema}\n\n"
        "Question: {query_str}\n"
        "SQLQuery: "
    ),
    query: str = "What was the year that The Notorious BIG was signed to Bad Boy?",
    model: str = "gpt-4o-mini",
) -> str:
    db_file, table_infos = await data_ingestion()
    vector_index_dir = await index_all_tables(db_file)
    table_context = await retrieve_tables(query, table_infos, db_file, vector_index_dir)
    sql = await generate_sql(query, table_context, model, system_prompt)
    return await generate_response(query, sql, db_file, model)

# {{/docs-fragment text_to_sql}}

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(text_to_sql)
    print(run.url)
    run.wait()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/text_to_sql.py*

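The main `text_to_sql` task orchestrates the pipeline, calling each stage in turn:

- Ingest the data (`data_ingestion`)
- Build a vector index for each table (`index_all_tables`)
- Retrieve the relevant tables and example rows (`retrieve_tables`)
- Generate a SQL query with an LLM (`generate_sql`)
- Execute the query and synthesize an answer (`generate_response`)

We use OpenAI GPT models (`gpt-4o-mini` by default) with a structured prompt. The prompt restricts the model to the columns in the provided schema and asks for a fixed `Question` / `SQLQuery` / `SQLResult` / `Answer` format. This makes the generated SQL easier to extract and less likely to reference nonexistent columns.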
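`generate_sql` fills the system prompt in two stages. First, `partial_format` fixes `{dialect}` up front. Later, `format_messages` supplies `{schema}` and `{query_str}`.

The stdlib sketch below imitates that two-stage idea with `str.format_map`. It assumes a plain string template in place of LlamaIndex's `PromptTemplate`. The `_KeepMissing` class, the `partial_format` function, and the sample schema string are hypothetical and exist only for illustration.

```
class _KeepMissing(dict):
    """Mapping that leaves unknown placeholders such as {schema} untouched."""

    def __missing__(self, key: str) -> str:
        return "{" + key + "}"


def partial_format(template: str, **values: str) -> str:
    """Fill only the placeholders provided; keep the rest for a later pass."""
    return template.format_map(_KeepMissing(values))


TEMPLATE = (
    "Given an input question, first create a syntactically correct {dialect} query.\n"
    "Only use tables listed below.\n"
    "{schema}\n\n"
    "Question: {query_str}\n"
    "SQLQuery: "
)

# Stage 1: fix the SQL dialect once (the role partial_format plays in generate_sql).
sqlite_template = partial_format(TEMPLATE, dialect="sqlite")

# Stage 2: supply per-request values (hypothetical schema context and a question).
prompt = sqlite_template.format(
    schema="Table 'artists' has columns: name (TEXT), label (TEXT), year (INTEGER).",
    query_str="Which label signed The Notorious BIG?",
)
print(prompt)
```

Splitting the formatting this way lets the dialect be bound once, while the schema and question change on every request.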

### Vector indexing

We index each table's rows so that, at query time, the pipeline can retrieve the example rows most similar to the question and include them in the SQL-generation prompt.

```
import asyncio
from pathlib import Path

import flyte
from flyte.io import Dir, File
from llama_index.core import SQLDatabase, VectorStoreIndex
from llama_index.core.schema import TextNode
from sqlalchemy import create_engine, text
from utils import env

@flyte.trace
async def index_table(table_name: str, table_index_dir: str, database_uri: str) -> str:
    """Index a single table into vector store."""
    path = f"{table_index_dir}/{table_name}"
    engine = create_engine(database_uri)

    def _fetch_rows():
        with engine.connect() as conn:
            cursor = conn.execute(text(f'SELECT * FROM "{table_name}"'))
            return cursor.fetchall()

    result = await asyncio.to_thread(_fetch_rows)
    nodes = [TextNode(text=str(tuple(row))) for row in result]
    index = VectorStoreIndex(nodes)
    index.set_index_id("vector_index")
    index.storage_context.persist(path)

    return path

@env.task
async def index_all_tables(db_file: File) -> Dir:
    """Index all tables concurrently."""
    table_index_dir = "table_indices"
    Path(table_index_dir).mkdir(exist_ok=True)

    await db_file.download(local_path="local_db.sqlite")
    engine = create_engine("sqlite:///local_db.sqlite")
    sql_database = SQLDatabase(engine)

    tasks = [
        index_table(t, table_index_dir, "sqlite:///local_db.sqlite")
        for t in sql_database.get_usable_table_names()
    ]
    await asyncio.gather(*tasks)

    remote_dir = await Dir.from_local(table_index_dir)
    return remote_dir
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/text_to_sql.py*

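Each row is converted to a string and stored as a `TextNode` in a per-table LlamaIndex `VectorStoreIndex`, which is persisted under `table_indices/<table_name>`. At query time, this lets the pipeline pull the rows most similar to the question. `index_all_tables` indexes all tables concurrently with `asyncio.gather`. It returns the index directory as a Flyte `Dir` so that downstream tasks can download it.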
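To make row-level retrieval concrete, here is a stdlib-only sketch of the same idea. Every row is stringified with `str(tuple(row))`, as `index_table` does before wrapping it in a `TextNode`. The two rows most similar to the question are then returned, mirroring `similarity_top_k=2`.

As an explicit assumption, the sketch swaps LlamaIndex's embedding-based similarity for a simple bag-of-words cosine score. The `albums` table and its rows are sample data invented for this illustration.

```
import math
import re
import sqlite3
from collections import Counter


def tokenize(s: str) -> Counter:
    """Lowercase word counts: a crude stand-in for an embedding (assumption)."""
    return Counter(re.findall(r"[a-z0-9]+", s.lower()))


def cosine(a: Counter, b: Counter) -> float:
    dot = sum(a[t] * b[t] for t in a)
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


def index_rows(conn: sqlite3.Connection, table_name: str) -> list[str]:
    # Same serialization as index_table: one text "node" per row.
    rows = conn.execute(f'SELECT * FROM "{table_name}"').fetchall()
    return [str(tuple(row)) for row in rows]


def retrieve_rows(nodes: list[str], query: str, top_k: int = 2) -> list[str]:
    # Rank every node by similarity to the query and keep the top_k.
    q = tokenize(query)
    ranked = sorted(nodes, key=lambda n: cosine(q, tokenize(n)), reverse=True)
    return ranked[:top_k]


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE albums (artist TEXT, title TEXT, year INTEGER)")
conn.executemany(
    "INSERT INTO albums VALUES (?, ?, ?)",
    [
        ("The Notorious B.I.G.", "Ready to Die", 1994),
        ("Nas", "Illmatic", 1994),
        ("The Notorious B.I.G.", "Life After Death", 1997),
    ],
)
nodes = index_rows(conn, "albums")
top_rows = retrieve_rows(nodes, "albums by The Notorious B.I.G.")
print(top_rows)
```

A real embedding model captures meaning rather than shared words, but the flow is the same: serialize rows, score them against the question, and keep the top matches.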

### Table retrieval and context building

We then retrieve the tables most relevant to a given query. For each one, we build a context string that combines the table's schema, its optional summary, and a few sample rows.

```
import asyncio

import flyte
from data_ingestion import TableInfo
from flyte.io import Dir, File
from llama_index.core import (
    SQLDatabase,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.core.objects import ObjectIndex, SQLTableNodeMapping, SQLTableSchema
from sqlalchemy import create_engine
from utils import env

@flyte.trace
async def get_table_schema_context(
    table_schema_obj: SQLTableSchema,
    database_uri: str,
) -> str:
    """Retrieve schema + optional description context for a single table."""
    engine = create_engine(database_uri)
    sql_database = SQLDatabase(engine)

    table_info = sql_database.get_single_table_info(table_schema_obj.table_name)

    if table_schema_obj.context_str:
        table_info += f" The table description is: {table_schema_obj.context_str}"

    return table_info

@flyte.trace
async def get_table_row_context(
    table_schema_obj: SQLTableSchema,
    local_vector_index_dir: str,
    query: str,
) -> str:
    """Retrieve row-level context examples using vector search."""
    storage_context = StorageContext.from_defaults(
        persist_dir=str(f"{local_vector_index_dir}/{table_schema_obj.table_name}")
    )
    vector_index = load_index_from_storage(storage_context, index_id="vector_index")
    vector_retriever = vector_index.as_retriever(similarity_top_k=2)
    relevant_nodes = vector_retriever.retrieve(query)

    if not relevant_nodes:
        return ""

    row_context = "\nHere are some relevant example rows (values in the same order as columns above)\n"
    for node in relevant_nodes:
        row_context += str(node.get_content()) + "\n"

    return row_context

async def process_table(
    table_schema_obj: SQLTableSchema,
    database_uri: str,
    local_vector_index_dir: str,
    query: str,
) -> str:
    """Combine schema + row context for one table."""
    table_info = await get_table_schema_context(table_schema_obj, database_uri)
    row_context = await get_table_row_context(
        table_schema_obj, local_vector_index_dir, query
    )

    full_context = table_info
    if row_context:
        full_context += "\n" + row_context

    print(f"Table Info: {full_context}")
    return full_context

async def get_table_context_and_rows_str(
    query: str,
    database_uri: str,
    table_schema_objs: list[SQLTableSchema],
    vector_index_dir: Dir,
):
    """Get combined schema + row context for all tables."""
    local_vector_index_dir = await vector_index_dir.download()

    # run per-table work concurrently
    context_strs = await asyncio.gather(
        *[
            process_table(t, database_uri, local_vector_index_dir, query)
            for t in table_schema_objs
        ]
    )

    return "\n\n".join(context_strs)

@env.task
async def retrieve_tables(
    query: str,
    table_infos: list[TableInfo | None],
    db_file: File,
    vector_index_dir: Dir,
) -> str:
    """Retrieve relevant tables and return schema context string."""
    await db_file.download(local_path="local_db.sqlite")
    engine = create_engine("sqlite:///local_db.sqlite")
    sql_database = SQLDatabase(engine)

    table_node_mapping = SQLTableNodeMapping(sql_database)
    table_schema_objs = [
        SQLTableSchema(table_name=t.table_name, context_str=t.table_summary)
        for t in table_infos
        if t is not None
    ]

    obj_index = ObjectIndex.from_objects(
        table_schema_objs,
        table_node_mapping,
        VectorStoreIndex,
    )
    obj_retriever = obj_index.as_retriever(similarity_top_k=3)

    retrieved_schemas = obj_retriever.retrieve(query)
    return await get_table_context_and_rows_str(
        query, "sqlite:///local_db.sqlite", retrieved_schemas, vector_index_dir
    )
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/text_to_sql.py*

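The retriever selects the three tables most similar to the question (`similarity_top_k=3`). It then attaches each table's schema, its summary if one exists, and its two most similar example rows (`similarity_top_k=2`). This context grounds the model's SQL generation in the database's actual structure and content.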

### SQL generation and response synthesis

Finally, we generate a SQL query and turn its results into a natural-language answer.

```
from flyte.io import File
from llama_index.core import PromptTemplate, SQLDatabase
from llama_index.core.llms import ChatResponse
from llama_index.core.prompts.prompt_type import PromptType
from llama_index.core.retrievers import SQLRetriever
from llama_index.llms.openai import OpenAI
from sqlalchemy import create_engine
from utils import env

def parse_response_to_sql(chat_response: ChatResponse) -> str:
    """Extract SQL query from LLM response."""
    response = chat_response.message.content
    sql_query_start = response.find("SQLQuery:")
    if sql_query_start != -1:
        response = response[sql_query_start:]
        if response.startswith("SQLQuery:"):
            response = response[len("SQLQuery:") :]
    sql_result_start = response.find("SQLResult:")
    if sql_result_start != -1:
        response = response[:sql_result_start]
    return response.strip().strip("```").strip()

@env.task
async def generate_sql(query: str, table_context: str, model: str, prompt: str) -> str:
    """Generate SQL query from natural language question and table context."""
    llm = OpenAI(model=model)

    fmt_messages = (
        PromptTemplate(
            prompt,
            prompt_type=PromptType.TEXT_TO_SQL,
        )
        .partial_format(dialect="sqlite")
        .format_messages(query_str=query, schema=table_context)
    )

    chat_response = await llm.achat(fmt_messages)
    return parse_response_to_sql(chat_response)

@env.task
async def generate_response(query: str, sql: str, db_file: File, model: str) -> str:
    """Run SQL query on database and synthesize final response."""
    await db_file.download(local_path="local_db.sqlite")

    engine = create_engine("sqlite:///local_db.sqlite")
    sql_database = SQLDatabase(engine)
    sql_retriever = SQLRetriever(sql_database)

    retrieved_rows = sql_retriever.retrieve(sql)

    response_synthesis_prompt = PromptTemplate(
        "Given an input question, synthesize a response from the query results.\n"
        "Query: {query_str}\n"
        "SQL: {sql_query}\n"
        "SQL Response: {context_str}\n"
        "Response: "
    )

    llm = OpenAI(model=model)
    fmt_messages = response_synthesis_prompt.format_messages(
        sql_query=sql,
        context_str=str(retrieved_rows),
        query_str=query,
    )
    chat_response = await llm.achat(fmt_messages)
    return chat_response.message.content
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/text_to_sql.py*

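The SQL generation prompt includes the retrieved table schemas, example rows, and formatting rules. Once the generated query runs, `generate_response` passes the question, the SQL, and the returned rows to the LLM, which synthesizes the final answer.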

At this point, we have an end-to-end Text-to-SQL pipeline: natural language questions go in, SQL queries run, and answers come back. Several Flyte 2 capabilities help make this workflow production-ready. Caching means repeated steps, such as table ingestion and vector indexing, don't rerun unnecessarily, which saves time and compute. Containerization gives consistent, reproducible execution across environments, which makes the pipeline easier to scale and deploy. Observability lets us track every step of the pipeline, monitor performance, and debug issues quickly.

The pipeline works end to end, but to measure how well it performs across many questions, and to improve it over time, we can start experimenting with prompt tuning.

Two things help make this process meaningful:

- **A clean evaluation dataset** - so we can measure accuracy against trusted ground truth.
- **A systematic evaluation loop** - so we can see whether prompt changes or other adjustments actually help.
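As a rough illustration of what a systematic evaluation loop can measure, the sketch below computes execution accuracy: it runs a predicted query and a reference query against the same SQLite database and checks whether they return the same rows. The helper names `execution_match` and `execution_accuracy`, and the choice to ignore row order, are assumptions for illustration rather than part of the tutorial's code. Reference queries are assumed to be valid, since they come from a validated dataset.

```python
import sqlite3
from collections import Counter


def execution_match(conn: sqlite3.Connection, predicted_sql: str, reference_sql: str) -> bool:
    """Hypothetical helper: True if both queries return the same multiset of rows."""
    try:
        predicted = conn.execute(predicted_sql).fetchall()
    except sqlite3.Error:
        # A predicted query that fails to execute counts as incorrect.
        return False
    reference = conn.execute(reference_sql).fetchall()
    # Compare as multisets, so row order is ignored but duplicate rows still count.
    return Counter(predicted) == Counter(reference)


def execution_accuracy(conn: sqlite3.Connection, pairs: list[tuple[str, str]]) -> float:
    """Fraction of (predicted_sql, reference_sql) pairs whose results match."""
    if not pairs:
        return 0.0
    hits = sum(execution_match(conn, predicted, reference) for predicted, reference in pairs)
    return hits / len(pairs)


if __name__ == "__main__":
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE albums (artist TEXT, year INTEGER)")
    conn.executemany("INSERT INTO albums VALUES (?, ?)", [("A", 1994), ("B", 1997)])
    pairs = [
        ("SELECT year FROM albums WHERE artist = 'A'", "SELECT year FROM albums WHERE artist='A'"),
        ("SELECT artist FROM albums", "SELECT artist FROM albums ORDER BY year DESC"),
        ("SELECT missing FROM albums", "SELECT year FROM albums"),
    ]
    print(execution_accuracy(conn, pairs))
```

Ignoring row order is lenient for questions whose answer depends on `ORDER BY`; a stricter variant could compare the two row lists directly.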

With these in place, the next step is to build a "golden" QA dataset that will guide iterative prompt optimization.

## Building the QA dataset

> [!NOTE]
> The WikiTableQuestions dataset already includes question–answer pairs, available in its [GitHub repository](https://github.com/ppasupat/WikiTableQuestions/tree/master/data). To use them for this workflow, you'll need to adapt the data into the required format, but the raw material is there for you to build on.
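A minimal sketch of that adaptation, assuming the TSV files in the repository's `data/` directory have a header row with `id`, `utterance`, `context`, and `targetValue` columns. The helper `wtq_to_qa_rows` is hypothetical: it maps each question to `input` and the gold answer to `target`, and leaves `sql` empty because WikiTableQuestions provides answers rather than SQL. It ignores any escaping conventions the dataset uses for special characters, and its `target` is the raw answer string, not the `str(rows)` format this tutorial produces.

```python
import csv
import io


def wtq_to_qa_rows(tsv_text: str) -> list[dict[str, str]]:
    """Hypothetical helper: map WikiTableQuestions TSV rows to input/target/sql dicts.

    Assumes a header of id, utterance, context, targetValue. `context` (the
    source table path) is kept so each question can be matched to its table.
    """
    # QUOTE_NONE keeps literal quote characters inside questions intact.
    reader = csv.DictReader(io.StringIO(tsv_text), delimiter="\t", quoting=csv.QUOTE_NONE)
    return [
        {
            "input": record["utterance"],
            "target": record["targetValue"],
            "sql": "",  # no SQL in the source data; it must be generated or written separately
            "context": record["context"],
        }
        for record in reader
    ]
```

You might call it on the contents of a file such as `data/training.tsv` (path assumed), then fill in the `sql` column before using the rows with the rest of this workflow.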

We generate a dataset of natural language questions paired with executable SQL queries. This dataset acts as the benchmark for prompt tuning and evaluation.

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "pandas>=2.0.0",
#    "llama-index-core>=0.11.0",
#    "llama-index-llms-openai>=0.2.0",
#    "pydantic>=2.0.0",
# ]
# main = "build_eval_dataset"
# params = ""
# ///

import sqlite3

import flyte
import pandas as pd
from data_ingestion import data_ingestion
from flyte.io import File
from llama_index.core import PromptTemplate
from llama_index.llms.openai import OpenAI
from utils import env
from pydantic import BaseModel

class QAItem(BaseModel):
    question: str
    sql: str

class QAList(BaseModel):
    items: list[QAItem]

# {{docs-fragment get_and_split_schema}}
@env.task
async def get_and_split_schema(db_file: File, tables_per_chunk: int) -> list[str]:
    """
    Download the SQLite DB, extract schema info (columns + sample rows),
    then split it into chunks with up to `tables_per_chunk` tables each.
    """
    await db_file.download(local_path="local_db.sqlite")
    conn = sqlite3.connect("local_db.sqlite")
    cursor = conn.cursor()

    tables = cursor.execute(
        "SELECT name FROM sqlite_master WHERE type='table';"
    ).fetchall()

    schema_blocks = []
    for table in tables:
        table_name = table[0]

        # columns
        cursor.execute(f"PRAGMA table_info({table_name});")
        columns = [col[1] for col in cursor.fetchall()]
        block = f"Table: {table_name}({', '.join(columns)})"

        # sample rows
        cursor.execute(f"SELECT * FROM {table_name} LIMIT 10;")
        rows = cursor.fetchall()
        if rows:
            block += "\nSample rows:\n"
            for row in rows:
                block += f"{row}\n"

        schema_blocks.append(block)

    conn.close()

    chunks = []
    current_chunk = []
    for block in schema_blocks:
        current_chunk.append(block)
        if len(current_chunk) >= tables_per_chunk:
            chunks.append("\n".join(current_chunk))
            current_chunk = []
    if current_chunk:
        chunks.append("\n".join(current_chunk))

    return chunks

# {{/docs-fragment get_and_split_schema}}

# {{docs-fragment generate_questions_and_sql}}
@flyte.trace
async def generate_questions_and_sql(
    schema: str, num_samples: int, batch_size: int
) -> QAList:
    llm = OpenAI(model="gpt-4.1")

    prompt_tmpl = PromptTemplate(
        """Prompt: You are helping build a Text-to-SQL dataset.

Here is the database schema:
{schema}

Generate {num} natural language questions a user might ask about this database.
For each question, also provide the correct SQL query.

Reasoning process (you must follow this internally):

- Given an input question, first create a syntactically correct {dialect} SQL query.
- Never use SELECT *; only include the relevant columns.
- Use only columns/tables from the schema. Qualify column names when ambiguous.
- You may order results by a meaningful column to make the query more useful.
- Be careful not to add unnecessary columns.
- Use filters, aggregations, joins, grouping, and subqueries when relevant.

Final Output:
Return only a JSON object with one field:

- "items": a list of {num} objects, each with:
    - "question": the natural language question
    - "sql": the corresponding SQL query
"""
    )

    all_items: list[QAItem] = []

    # batch generation
    for start in range(0, num_samples, batch_size):
        current_num = min(batch_size, num_samples - start)
        response = await llm.astructured_predict(
            QAList,
            prompt_tmpl,
            schema=schema,
            num=current_num,
            dialect="sqlite",  # fill the {dialect} placeholder used in the prompt
        )
        all_items.extend(response.items)

    # deduplicate
    seen = set()
    unique_items: list[QAItem] = []
    for item in all_items:
        key = (item.question.strip().lower(), item.sql.strip().lower())
        if key not in seen:
            seen.add(key)
            unique_items.append(item)

    return QAList(items=unique_items[:num_samples])

# {{/docs-fragment generate_questions_and_sql}}

@flyte.trace
async def llm_validate_batch(pairs: list[dict[str, str]]) -> list[str]:
    """Validate a batch of question/sql/result dicts using one LLM call."""
    batch_prompt = """You are validating the correctness of SQL query results against the question.
For each example, answer only "True" (correct) or "False" (incorrect).
Output one answer per line, in the same order as the examples.
---
"""

    for i, pair in enumerate(pairs, start=1):
        batch_prompt += f"""
Example {i}:
Question:
{pair['question']}

SQL:
{pair['sql']}

Result:
{pair['rows']}
---
"""

    llm = OpenAI(model="gpt-4.1")
    resp = await llm.acomplete(batch_prompt)

    # Expect exactly one True/False per example
    results = [
        line.strip()
        for line in resp.text.splitlines()
        if line.strip() in ("True", "False")
    ]
    return results

# {{docs-fragment validate_sql}}
@env.task
async def validate_sql(
    db_file: File, question_sql_pairs: QAList, batch_size: int
) -> list[dict[str, str]]:
    await db_file.download(local_path="local_db.sqlite")
    conn = sqlite3.connect("local_db.sqlite")
    cursor = conn.cursor()

    qa_data = []
    batch = []

    for item in question_sql_pairs.items:
        q, sql = item.question, item.sql
        try:
            cursor.execute(sql)
            rows = cursor.fetchall()
        except Exception as e:
            # Queries that fail to execute are dropped before LLM validation
            print(f"Skipping invalid SQL: {sql} ({e})")
            continue
        batch.append({"question": q, "sql": sql, "rows": str(rows)})

        # process when batch is full
        if len(batch) == batch_size:
            results = await llm_validate_batch(batch)
            for pair, is_valid in zip(batch, results):
                if is_valid == "True":
                    qa_data.append(
                        {
                            "input": pair["question"],
                            "sql": pair["sql"],
                            "target": pair["rows"],
                        }
                    )
                else:
                    print(f"Filtered out incorrect result for: {pair['question']}")
            batch = []

    # process leftover batch
    if batch:
        results = await llm_validate_batch(batch)
        for pair, is_valid in zip(batch, results):
            if is_valid == "True":
                qa_data.append(
                    {
                        "input": pair["question"],
                        "sql": pair["sql"],
                        "target": pair["rows"],
                    }
                )
            else:
                print(f"Filtered out incorrect result for: {pair['question']}")

    conn.close()
    return qa_data

# {{/docs-fragment validate_sql}}

@flyte.trace
async def save_to_csv(qa_data: list[dict]) -> File:
    df = pd.DataFrame(qa_data, columns=["input", "target", "sql"])

    csv_file = "qa_dataset.csv"
    df.to_csv(csv_file, index=False)

    return await File.from_local(csv_file)

# {{docs-fragment build_eval_dataset}}
@env.task
async def build_eval_dataset(
    num_samples: int = 300, batch_size: int = 30, tables_per_chunk: int = 3
) -> File:
    db_file, _ = await data_ingestion()
    schema_chunks = await get_and_split_schema(db_file, tables_per_chunk)

    per_chunk_samples = max(1, num_samples // len(schema_chunks))
    final_qa_data = []

    for chunk in schema_chunks:
        qa_list = await generate_questions_and_sql(
            schema=chunk,
            num_samples=per_chunk_samples,
            batch_size=batch_size,
        )
        qa_data = await validate_sql(db_file, qa_list, batch_size)
        final_qa_data.extend(qa_data)

    csv_file = await save_to_csv(final_qa_data)
    return csv_file

# {{/docs-fragment build_eval_dataset}}

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(build_eval_dataset)
    print(run.url)
    run.wait()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/create_qa_dataset.py*

The pipeline does the following:

- Schema extraction – pull the full database schema, including table names, columns, and sample rows, and split it into chunks.
- Question–SQL generation – use an LLM to produce natural language questions with matching SQL queries.
- Validation – run each query against the database, skip queries that fail to execute, and ask an LLM to filter out results that don't answer their question.
- Final export – store the validated pairs in a CSV file with `input`, `target`, and `sql` columns for downstream use.
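One detail of the exported file worth knowing: `validate_sql` stores each query result in the `target` column as `str(rows)`, the Python representation of a list of tuples. The sketch below shows one way to turn that string back into rows when reading the CSV, using `ast.literal_eval`, which parses Python literals without executing arbitrary code. The helper names are hypothetical, and the CSV round trip uses the standard `csv` module instead of pandas to keep the example self-contained.

```python
import ast
import csv
import io
import sqlite3


def rows_to_target(rows: list[tuple]) -> str:
    # Mirrors validate_sql, which stores fetched rows via str(rows).
    return str(rows)


def target_to_rows(target: str) -> list[tuple]:
    # literal_eval only accepts Python literals (lists, tuples, str, numbers, None, bytes).
    return ast.literal_eval(target)


if __name__ == "__main__":
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE t (name TEXT, score REAL)")
    conn.executemany("INSERT INTO t VALUES (?, ?)", [("a", 1.5), ("b", None)])
    sql = "SELECT name, score FROM t ORDER BY name"
    record = {
        "input": "List every name with its score.",
        "target": rows_to_target(conn.execute(sql).fetchall()),
        "sql": sql,
    }

    # Round-trip through CSV with the same column order as save_to_csv.
    buffer = io.StringIO()
    writer = csv.DictWriter(buffer, fieldnames=["input", "target", "sql"])
    writer.writeheader()
    writer.writerow(record)
    buffer.seek(0)
    loaded = next(csv.DictReader(buffer))
    print(target_to_rows(loaded["target"]))
```

This works for typical SQLite result values; special floats such as infinity have a representation (`inf`) that `literal_eval` does not accept.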

### Schema extraction and chunking

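We split the schema into chunks of up to `tables_per_chunk` tables and generate questions for each chunk separately. Because every chunk gets its own share of the sample budget, questions are spread across all tables instead of concentrating on the few an LLM might favor when shown the entire schema at once.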

```
@env.task
async def get_and_split_schema(db_file: File, tables_per_chunk: int) -> list[str]:
    """
    Download the SQLite DB, extract schema info (columns + sample rows),
    then split it into chunks with up to `tables_per_chunk` tables each.
    """
    await db_file.download(local_path="local_db.sqlite")
    conn = sqlite3.connect("local_db.sqlite")
    cursor = conn.cursor()

    tables = cursor.execute(
        "SELECT name FROM sqlite_master WHERE type='table';"
    ).fetchall()

    schema_blocks = []
    for table in tables:
        table_name = table[0]

        # columns
        cursor.execute(f"PRAGMA table_info({table_name});")
        columns = [col[1] for col in cursor.fetchall()]
        block = f"Table: {table_name}({', '.join(columns)})"

        # sample rows
        cursor.execute(f"SELECT * FROM {table_name} LIMIT 10;")
        rows = cursor.fetchall()
        if rows:
            block += "\nSample rows:\n"
            for row in rows:
                block += f"{row}\n"

        schema_blocks.append(block)

    conn.close()

    chunks = []
    current_chunk = []
    for block in schema_blocks:
        current_chunk.append(block)
        if len(current_chunk) >= tables_per_chunk:
            chunks.append("\n".join(current_chunk))
            current_chunk = []
    if current_chunk:
        chunks.append("\n".join(current_chunk))

    return chunks
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/create_qa_dataset.py*
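The chunking loop in `get_and_split_schema` and the sample split in `build_eval_dataset` can be restated as two small pure functions, which makes their arithmetic easy to check. This is a sketch for illustration (the names `chunk_schema_blocks` and `samples_per_chunk` are hypothetical), and it assumes `tables_per_chunk` is at least 1.

```python
def chunk_schema_blocks(blocks: list[str], tables_per_chunk: int) -> list[str]:
    """Group per-table schema blocks into chunks of at most `tables_per_chunk` tables."""
    # Equivalent to the append-and-flush loop: full chunks first, then any remainder.
    return [
        "\n".join(blocks[i : i + tables_per_chunk])
        for i in range(0, len(blocks), tables_per_chunk)
    ]


def samples_per_chunk(num_samples: int, num_chunks: int) -> int:
    # Same integer division as build_eval_dataset; the remainder is dropped.
    return max(1, num_samples // num_chunks)


if __name__ == "__main__":
    blocks = [f"Table: t{i}(a, b)" for i in range(7)]
    chunks = chunk_schema_blocks(blocks, 3)
    print(len(chunks), samples_per_chunk(300, len(chunks)))
```

Because of the integer division, the number of requested questions can fall slightly short of `num_samples` (300 samples over 8 chunks requests 37 per chunk, or 296 in total), and the `max(1, ...)` floor can push it above `num_samples` when there are more chunks than samples. Deduplication and validation then reduce the final count further.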

### Question and SQL generation

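Using a structured prompt with the `QAList` Pydantic model as the output schema, we ask an LLM to generate realistic questions a user might ask, each paired with a SQL query. Generation runs in batches of `batch_size`, and pairs whose question and SQL are identical after trimming whitespace and lowercasing are removed. Whether each query actually runs is checked later, in the validation step.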

```
@flyte.trace
async def generate_questions_and_sql(
    schema: str, num_samples: int, batch_size: int
) -> QAList:
    llm = OpenAI(model="gpt-4.1")

    prompt_tmpl = PromptTemplate(
        """Prompt: You are helping build a Text-to-SQL dataset.

Here is the database schema:
{schema}

Generate {num} natural language questions a user might ask about this database.
For each question, also provide the correct SQL query.

Reasoning process (you must follow this internally):

- Given an input question, first create a syntactically correct {dialect} SQL query.
- Never use SELECT *; only include the relevant columns.
- Use only columns/tables from the schema. Qualify column names when ambiguous.
- You may order results by a meaningful column to make the query more useful.
- Be careful not to add unnecessary columns.
- Use filters, aggregations, joins, grouping, and subqueries when relevant.

Final Output:
Return only a JSON object with one field:

- "items": a list of {num} objects, each with:
    - "question": the natural language question
    - "sql": the corresponding SQL query
"""
    )

    all_items: list[QAItem] = []

    # batch generation
    for start in range(0, num_samples, batch_size):
        current_num = min(batch_size, num_samples - start)
        response = await llm.astructured_predict(
            QAList,
            prompt_tmpl,
            schema=schema,
            num=current_num,
            dialect="sqlite",  # fill the {dialect} placeholder used in the prompt
        )
        all_items.extend(response.items)

    # deduplicate
    seen = set()
    unique_items: list[QAItem] = []
    for item in all_items:
        key = (item.question.strip().lower(), item.sql.strip().lower())
        if key not in seen:
            seen.add(key)
            unique_items.append(item)

    return QAList(items=unique_items[:num_samples])
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/create_qa_dataset.py*
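The batching and deduplication in `generate_questions_and_sql` can be checked in isolation without calling an LLM. The sketch below restates both steps as plain functions (hypothetical names `batch_sizes` and `dedupe_pairs`). It shows that deduplication only removes pairs that match exactly after `strip()` and `lower()`, so paraphrased questions, or equivalent but differently written SQL, are kept.

```python
def batch_sizes(num_samples: int, batch_size: int) -> list[int]:
    """Number of pairs requested in each LLM call by the batching loop."""
    return [min(batch_size, num_samples - start) for start in range(0, num_samples, batch_size)]


def dedupe_pairs(pairs: list[tuple[str, str]]) -> list[tuple[str, str]]:
    """Keep the first occurrence of each (question, sql) pair after normalization."""
    seen: set[tuple[str, str]] = set()
    unique: list[tuple[str, str]] = []
    for question, sql in pairs:
        # Same normalization as the tutorial: trim whitespace and lowercase both fields.
        key = (question.strip().lower(), sql.strip().lower())
        if key not in seen:
            seen.add(key)
            unique.append((question, sql))
    return unique


if __name__ == "__main__":
    print(batch_sizes(70, 30))
    print(dedupe_pairs([
        ("How many rows?", "SELECT COUNT(*) FROM t"),
        ("how many rows? ", "select count(*) from t"),
    ]))
```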

### Validation and quality control

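Each generated SQL query runs against the database, and queries that raise an error are skipped. The results of the remaining queries are sent in batches to an LLM, which judges whether each result answers its question; only pairs marked `True` are kept.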
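The tutorial's `validate_sql` pairs LLM verdicts with examples using `zip`, so if the model returns fewer `True`/`False` lines than examples (for instance, because it numbers its answers), the unmatched examples are silently dropped. The sketch below is a stricter variant, not the tutorial's code: the hypothetical `execute_candidates` runs each query and keeps only those that execute, and `parse_verdicts` raises an error when the verdict count does not match, so the caller can retry the batch.

```python
import sqlite3


def execute_candidates(conn: sqlite3.Connection, pairs: list[tuple[str, str]]) -> list[dict[str, str]]:
    """Run each (question, sql) pair and keep the ones that execute, with rows as a string."""
    executed = []
    for question, sql in pairs:
        try:
            rows = conn.execute(sql).fetchall()
        except sqlite3.Error as exc:
            print(f"Skipping invalid SQL: {sql} ({exc})")
            continue
        executed.append({"question": question, "sql": sql, "rows": str(rows)})
    return executed


def parse_verdicts(response_text: str, expected: int) -> list[bool]:
    """Parse one True/False per line and fail loudly if the count is wrong."""
    verdicts = [
        line.strip()
        for line in response_text.splitlines()
        if line.strip() in ("True", "False")
    ]
    if len(verdicts) != expected:
        raise ValueError(f"expected {expected} verdicts, got {len(verdicts)}")
    return [verdict == "True" for verdict in verdicts]
```

Note that a line such as `1. True` does not match either accepted value, so it is ignored rather than counted.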

```
@flyte.trace
async def llm_validate_batch(pairs: list[dict[str, str]]) -> list[str]:
    """Validate a batch of question/sql/result dicts using one LLM call."""
    batch_prompt = """You are validating the correctness of SQL query results against the question.
For each example, answer only "True" (correct) or "False" (incorrect).
Output one answer per line, in the same order as the examples.
---
"""

    for i, pair in enumerate(pairs, start=1):
        batch_prompt += f"""
Example {i}:
Question:
{pair['question']}

SQL:
{pair['sql']}

Result:
{pair['rows']}
---
"""

    llm = OpenAI(model="gpt-4.1")
    resp = await llm.acomplete(batch_prompt)

    # Expect exactly one True/False per example
    results = [
        line.strip()
        for line in resp.text.splitlines()
        if line.strip() in ("True", "False")
    ]
    return results

# {{docs-fragment validate_sql}}
@env.task
async def validate_sql(
    db_file: File, question_sql_pairs: QAList, batch_size: int
) -> list[dict[str, str]]:
    await db_file.download(local_path="local_db.sqlite")
    conn = sqlite3.connect("local_db.sqlite")
    cursor = conn.cursor()

    qa_data = []
    batch = []

    for item in question_sql_pairs.items:
        q, sql = item.question, item.sql
        try:
            cursor.execute(sql)
            rows = cursor.fetchall()
        except Exception as e:
            # Only SQL execution errors skip a pair, so a failed LLM call
            # below is not misreported as invalid SQL.
            print(f"Skipping invalid SQL: {sql} ({e})")
            continue
        batch.append({"question": q, "sql": sql, "rows": str(rows)})

        # process when batch is full
        if len(batch) == batch_size:
            results = await llm_validate_batch(batch)
            for pair, is_valid in zip(batch, results):
                if is_valid == "True":
                    qa_data.append(
                        {
                            "input": pair["question"],
                            "sql": pair["sql"],
                            "target": pair["rows"],
                        }
                    )
                else:
                    print(f"Filtered out incorrect result for: {pair['question']}")
            batch = []

    # process leftover batch
    if batch:
        results = await llm_validate_batch(batch)
        for pair, is_valid in zip(batch, results):
            if is_valid == "True":
                qa_data.append(
                    {
                        "input": pair["question"],
                        "sql": pair["sql"],
                        "target": pair["rows"],
                    }
                )
            else:
                print(f"Filtered out incorrect result for: {pair['question']}")

    conn.close()
    return qa_data

# {{/docs-fragment validate_sql}}

@flyte.trace
async def save_to_csv(qa_data: list[dict]) -> File:
    df = pd.DataFrame(qa_data, columns=["input", "target", "sql"])

    csv_file = "qa_dataset.csv"
    df.to_csv(csv_file, index=False)

    return await File.from_local(csv_file)

# {{docs-fragment build_eval_dataset}}
@env.task
async def build_eval_dataset(
    num_samples: int = 300, batch_size: int = 30, tables_per_chunk: int = 3
) -> File:
    db_file, _ = await data_ingestion()
    schema_chunks = await get_and_split_schema(db_file, tables_per_chunk)

    per_chunk_samples = max(1, num_samples // len(schema_chunks))
    final_qa_data = []

    for chunk in schema_chunks:
        qa_list = await generate_questions_and_sql(
            schema=chunk,
            num_samples=per_chunk_samples,
            batch_size=batch_size,
        )
        qa_data = await validate_sql(db_file, qa_list, batch_size)
        final_qa_data.extend(qa_data)

    csv_file = await save_to_csv(final_qa_data)
    return csv_file

# {{/docs-fragment build_eval_dataset}}

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(build_eval_dataset)
    print(run.url)
    run.wait()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/create_qa_dataset.py*
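`llm_validate_batch` keeps only lines that read exactly `True` or `False`, and `validate_sql` pairs them with examples by position through `zip`. If the model skips or adds a verdict, later verdicts shift onto the wrong examples, and `zip` silently drops any examples left without one. A minimal sketch of a stricter parser, assuming the prompt is changed to request lines of the form `Example <n>: True` (the tutorial's prompt asks for bare `True`/`False` lines); missing or malformed verdicts count as rejections:

```python
import re

# Hypothetical output format: "Example <n>: True" / "Example <n>: False".
# The tutorial's prompt asks for bare True/False lines instead.
VERDICT_RE = re.compile(r"^\s*Example\s+(\d+)\s*:\s*(True|False)\s*$", re.IGNORECASE)


def parse_verdicts(text: str, num_examples: int) -> list[bool]:
    """Map each example (1-based) to a verdict; missing or malformed ones count as False."""
    verdicts = [False] * num_examples
    for line in text.splitlines():
        match = VERDICT_RE.match(line)
        if not match:
            continue  # ignore chatter and anything not in the expected format
        idx = int(match.group(1))
        if 1 <= idx <= num_examples:  # ignore out-of-range example numbers
            verdicts[idx - 1] = match.group(2).lower() == "true"
    return verdicts


reply = "Example 1: True\nnoise\nExample 3: false\nExample 9: True"
print(parse_verdicts(reply, 3))  # [True, False, False]
```

Because each verdict is keyed by its example number, a skipped line affects only that one example instead of shifting every verdict after it.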

Even with automated checks, human review remains critical. Since this dataset serves as the ground truth, mislabeled pairs can distort evaluation. For production use, always invest in human-in-the-loop review.

Support for human-in-the-loop pipelines is coming soon in Flyte 2!
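Until then, a lightweight manual pass can be as simple as sampling generated pairs from `qa_dataset.csv` into a review sheet and dropping the ones a reviewer rejects. A minimal pandas sketch; the helper names, the `approved` column, the sample size, and the seed are illustrative assumptions rather than part of the tutorial:

```python
import pandas as pd


def make_review_sheet(qa: pd.DataFrame, k: int = 20, seed: int = 0) -> pd.DataFrame:
    """Sample up to k generated pairs for a human to approve or reject."""
    sample = qa.sample(n=min(k, len(qa)), random_state=seed)
    sheet = sample[["input", "sql", "target"]].copy()
    sheet["approved"] = None  # reviewer sets True or False per row
    return sheet


def apply_review(qa: pd.DataFrame, sheet: pd.DataFrame) -> pd.DataFrame:
    """Drop rows the reviewer explicitly rejected; unreviewed rows (None/NaN) are kept."""
    rejected = [
        idx
        for idx, flag in sheet["approved"].items()
        if flag is not None and flag == False  # noqa: E712 (NaN == False is also False)
    ]
    return qa.drop(index=rejected)
```

Keeping the original index on the sheet is what lets `apply_review` drop exactly the rejected rows from the full dataset.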

## Optimizing prompts

With the QA dataset in place, we can turn to prompt optimization. The idea: start from a baseline prompt, generate new variants, and measure whether accuracy improves.
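Stripped of Flyte tasks, LLM calls, and HTML reporting, the idea is a greedy propose-and-score loop. A minimal sketch in plain Python; `evaluate` and `propose` are hypothetical stand-ins for scoring a prompt on the validation set and asking an optimizer model for a new variant, and an unusable proposal simply consumes one iteration so the loop always terminates:

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class Candidate:
    prompt: str
    accuracy: float


def optimize_prompt(
    baseline: str,
    evaluate: Callable[[str], float],
    propose: Callable[[list[Candidate]], Optional[str]],
    max_iterations: int,
) -> Candidate:
    """Score a baseline, ask for new variants, and return the best-scoring prompt."""
    history = [Candidate(baseline, evaluate(baseline))]
    for _ in range(max_iterations):
        # The proposer sees every scored prompt, worst to best.
        new_prompt = propose(sorted(history, key=lambda c: c.accuracy))
        if new_prompt is None:
            continue  # no usable proposal: this attempt is simply spent
        history.append(Candidate(new_prompt, evaluate(new_prompt)))
    return max(history, key=lambda c: c.accuracy)


# Toy stand-ins: "accuracy" grows with prompt length, capped at 1.0,
# and the proposer appends an instruction to the current best prompt.
def toy_evaluate(prompt: str) -> float:
    return min(len(prompt) / 40, 1.0)


def toy_propose(history: list[Candidate]) -> Optional[str]:
    return history[-1].prompt + " Be precise."


best = optimize_prompt("Write SQL.", toy_evaluate, toy_propose, max_iterations=3)
print(best.prompt, best.accuracy)
```

Feeding the whole scored history back to the proposer, rather than only the latest prompt, is what lets it reuse patterns that worked and avoid ones that did not.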

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "pandas>=2.0.0",
#    "sqlalchemy>=2.0.0",
#    "llama-index-core>=0.11.0",
#    "llama-index-llms-openai>=0.2.0",
# ]
# main = "auto_prompt_engineering"
# params = ""
# ///

import asyncio
import html
import os
import re
from dataclasses import dataclass
from typing import Optional, Union

import flyte
import flyte.report
import pandas as pd
from data_ingestion import TableInfo
from flyte.io import Dir, File
from llama_index.core import SQLDatabase
from llama_index.core.retrievers import SQLRetriever
from sqlalchemy import create_engine
from text_to_sql import data_ingestion, generate_sql, index_all_tables, retrieve_tables
from utils import env

CSS = """
<style>
    body {
        font-family: 'Segoe UI', Roboto, Arial, sans-serif;
    }
    .results-table {
        border-collapse: collapse;
        width: 100%;
        box-shadow: 0 2px 5px rgba(0,0,0,0.1);
        font-size: 14px;
    }
    .results-table th {
        background: linear-gradient(135deg, #4CAF50, #2E7D32);
        color: white;
        padding: 10px;
        text-align: left;
    }
    .results-table td {
        border: 1px solid #ddd;
        padding: 8px;
        vertical-align: top;
    }
    .results-table tr:nth-child(even) {background-color: #f9f9f9;}
    .results-table tr:hover {background-color: #f1f1f1;}
    .correct {color: #2E7D32; font-weight: bold;}
    .incorrect {color: #C62828; font-weight: bold;}
    .summary-card {
        background: #f9fbfd;
        padding: 14px 18px;
        border-radius: 8px;
        box-shadow: 0 1px 4px rgba(0,0,0,0.05);
        max-width: 800px;
        margin-top: 12px;
    }
    .summary-card h3 {
        margin-top: 0;
        color: #1e88e5;
        font-size: 16px;
    }
</style>
"""

@env.task
async def data_prep(csv_file: File | str) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Load the Q&A dataset from a CSV (a Flyte File, a local path, or a URL such as a
    public Google Sheet CSV export), shuffle it with a fixed seed, and split it 50/50
    into validation and test DataFrames. The CSV must contain 'input', 'target', and
    'sql' columns.
    """
    df = pd.read_csv(
        await csv_file.download() if isinstance(csv_file, File) else csv_file
    )

    required = {"input", "target", "sql"}
    if not required.issubset(df.columns):
        raise ValueError(f"CSV must contain the columns {sorted(required)}.")

    # Shuffle rows
    df = df.sample(frac=1, random_state=1234).reset_index(drop=True)

    # Val/Test split
    df_renamed = df.rename(columns={"input": "question", "target": "answer"})

    n = len(df_renamed)
    split = n // 2

    df_val = df_renamed.iloc[:split]
    df_test = df_renamed.iloc[split:]

    return df_val, df_test

@dataclass
class ModelConfig:
    model_name: str
    hosted_model_uri: Optional[str] = None
    temperature: float = 0.0
    max_tokens: Optional[int] = 1000
    timeout: int = 600
    prompt: str = ""

@flyte.trace
async def call_model(
    model_config: ModelConfig,
    messages: list[dict[str, str]],
) -> str:
    from litellm import acompletion

    response = await acompletion(
        model=model_config.model_name,
        api_base=model_config.hosted_model_uri,
        messages=messages,
        temperature=model_config.temperature,
        timeout=model_config.timeout,
        max_tokens=model_config.max_tokens,
    )
    return response.choices[0].message["content"]

@flyte.trace
async def generate_response(db_file: File, sql: str) -> str:
    await db_file.download(local_path="local_db.sqlite")

    engine = create_engine("sqlite:///local_db.sqlite")
    sql_database = SQLDatabase(engine)
    sql_retriever = SQLRetriever(sql_database)

    retrieved_rows = sql_retriever.retrieve(sql)

    if retrieved_rows:
        # Get the structured result and stringify
        return str(retrieved_rows[0].node.metadata["result"])

    return ""

async def generate_and_review(
    index: int,
    question: str,
    answer: str,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
    db_file: File,
    table_infos: list[TableInfo | None],
    vector_index_dir: Dir,
) -> dict:
    # Generate response from target model
    table_context = await retrieve_tables(
        question, table_infos, db_file, vector_index_dir
    )
    sql = await generate_sql(
        question,
        table_context,
        target_model_config.model_name,
        target_model_config.prompt,
    )
    sql = sql.replace("sql\n", "")

    try:
        response = await generate_response(db_file, sql)
    except Exception as e:
        print(f"Failed to generate response for question {question}: {e}")
        response = None

    # Format review prompt with response + answer
    review_messages = [
        {
            "role": "system",
            "content": review_model_config.prompt.format(
                query_str=question,
                response=response,
                answer=answer,
            ),
        }
    ]
    verdict = await call_model(review_model_config, review_messages)

    # Normalize verdict
    verdict_clean = verdict.strip().lower()
    if verdict_clean not in {"true", "false"}:
        verdict_clean = "not sure"

    return {
        "index": index,
        "model_response": response,
        "sql": sql,
        "is_correct": verdict_clean == "true",
    }

async def run_grouped_task(
    i,
    index,
    question,
    answer,
    sql,
    semaphore,
    target_model_config,
    review_model_config,
    counter,
    counter_lock,
    db_file,
    table_infos,
    vector_index_dir,
):
    async with semaphore:
        with flyte.group(name=f"row-{i}"):
            result = await generate_and_review(
                index,
                question,
                answer,
                target_model_config,
                review_model_config,
                db_file,
                table_infos,
                vector_index_dir,
            )

            async with counter_lock:
                # Update counters
                counter["processed"] += 1
                if result["is_correct"]:
                    counter["correct"] += 1
                    correct_html = "<span class='correct'>✔ Yes</span>"
                else:
                    correct_html = "<span class='incorrect'>✘ No</span>"

                # Calculate accuracy
                accuracy_pct = (counter["correct"] / counter["processed"]) * 100

            # Update chart
            await flyte.report.log.aio(
                f"<script>updateAccuracy({accuracy_pct});</script>",
                do_flush=True,
            )

            # Add row to table
            await flyte.report.log.aio(
                f"""
                <tr>
                    <td>{html.escape(question)}</td>
                    <td>{html.escape(answer)}</td>
                    <td>{html.escape(sql)}</td>
                    <td>{html.escape(str(result['model_response']))}</td>
                    <td>{html.escape(result['sql'])}</td>
                    <td>{correct_html}</td>
                </tr>
                """,
                do_flush=True,
            )

            return result

@dataclass
class DatabaseConfig:
    csv_zip_path: str
    search_glob: str
    concurrency: int
    model: str

# {{docs-fragment evaluate_prompt}}
@env.task(report=True)
async def evaluate_prompt(
    df: pd.DataFrame,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
    concurrency: int,
    db_config: DatabaseConfig,
) -> float:
    semaphore = asyncio.Semaphore(concurrency)
    counter = {"correct": 0, "processed": 0}
    counter_lock = asyncio.Lock()

    # Write initial HTML structure
    await flyte.report.log.aio(
        CSS
        + """
        <script>
            function updateAccuracy(percent) {
                const bar = document.getElementById('acc-bar');
                const label = document.getElementById('acc-label');
                bar.setAttribute('width', percent * 3);
                label.textContent = `Accuracy: ${percent.toFixed(1)}%`;
            }
        </script>

        <h2 style="margin-top:0;">Model Evaluation Results</h2>
        <h3>Live Accuracy</h3>
        <svg width="320" height="30" id="accuracy-chart">
            <defs>
                <linearGradient id="acc-gradient" x1="0" x2="1" y1="0" y2="0">
                    <stop offset="0%" stop-color="#66bb6a"/>
                    <stop offset="100%" stop-color="#2e7d32"/>
                </linearGradient>
            </defs>
            <rect width="300" height="20" fill="#ddd" rx="5" ry="5"></rect>
            <rect id="acc-bar" width="0" height="20" fill="url(#acc-gradient)" rx="5" ry="5"></rect>
            <text id="acc-label" x="150" y="15" font-size="12" font-weight="bold" text-anchor="middle" fill="#000">
                Accuracy: 0.0%
            </text>
        </svg>

        <table class="results-table">
            <thead>
                <tr>
                    <th>Question</th>
                    <th>Ground Truth Answer</th>
                    <th>Ground Truth SQL</th>
                    <th>Model Response</th>
                    <th>Model SQL</th>
                    <th>Correct?</th>
                </tr>
            </thead>
            <tbody>
        """,
        do_flush=True,
    )

    db_file, table_infos = await data_ingestion(
        db_config.csv_zip_path,
        db_config.search_glob,
        db_config.concurrency,
        db_config.model,
    )

    vector_index_dir = await index_all_tables(db_file)

    # Launch tasks concurrently
    tasks = [
        run_grouped_task(
            i,
            row.Index,
            row.question,
            row.answer,
            row.sql,
            semaphore,
            target_model_config,
            review_model_config,
            counter,
            counter_lock,
            db_file,
            table_infos,
            vector_index_dir,
        )
        for i, row in enumerate(df.itertuples(index=True))
    ]
    await asyncio.gather(*tasks)

    # Close table
    await flyte.report.log.aio("</tbody></table>", do_flush=True)

    async with counter_lock:
        return (
            (counter["correct"] / counter["processed"]) if counter["processed"] else 0.0
        )

# {{/docs-fragment evaluate_prompt}}

@dataclass
class PromptResult:
    prompt: str
    accuracy: float

# {{docs-fragment prompt_optimizer}}
@env.task(report=True)
async def prompt_optimizer(
    df_val: pd.DataFrame,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
    optimizer_model_config: ModelConfig,
    max_iterations: int,
    concurrency: int,
    db_config: DatabaseConfig,
) -> tuple[str, float]:
    prompt_accuracies: list[PromptResult] = []

    # Send styling + table header immediately
    await flyte.report.log.aio(
        CSS
        + """
    <h2 style="margin-bottom:6px;">📊 Prompt Accuracy Comparison</h2>
    <table class="results-table">
        <thead>
            <tr>
                <th>Prompt</th>
                <th>Accuracy</th>
            </tr>
        </thead>
    <tbody>
    """,
        do_flush=True,
    )

    # Step 1: Evaluate starting prompt and stream row
    with flyte.group(name="baseline_evaluation"):
        starting_accuracy = await evaluate_prompt(
            df_val,
            target_model_config,
            review_model_config,
            concurrency,
            db_config,
        )
        prompt_accuracies.append(
            PromptResult(prompt=target_model_config.prompt, accuracy=starting_accuracy)
        )

        await _log_prompt_row(target_model_config.prompt, starting_accuracy)

    # Step 2: Optimize prompts one by one, streaming after each.
    # A bounded loop: a reply without a [[...]] prompt uses up one attempt
    # instead of triggering an unbounded retry.
    for step in range(1, max_iterations + 1):
        with flyte.group(name=f"prompt_optimization_step_{step}"):
            # Prepare prompt scores string for optimizer
            prompt_scores_str = "\n".join(
                f"{result.prompt}: {result.accuracy:.2f}"
                for result in sorted(prompt_accuracies, key=lambda x: x.accuracy)
            )

            optimizer_model_prompt = optimizer_model_config.prompt.format(
                prompt_scores_str=prompt_scores_str
            )
            response = await call_model(
                optimizer_model_config,
                [{"role": "system", "content": optimizer_model_prompt}],
            )
            response = response.strip()

            match = re.search(r"\[\[(.*?)\]\]", response, re.DOTALL)
            if not match:
                print("No new prompt found. Skipping.")
                continue

            new_prompt = match.group(1)
            target_model_config.prompt = new_prompt
            accuracy = await evaluate_prompt(
                df_val,
                target_model_config,
                review_model_config,
                concurrency,
                db_config,
            )
            prompt_accuracies.append(PromptResult(prompt=new_prompt, accuracy=accuracy))

            # Log this new prompt row immediately
            await _log_prompt_row(new_prompt, accuracy)

    # Close table
    await flyte.report.log.aio("</tbody></table>", do_flush=True)

    # Find best
    best_result = max(prompt_accuracies, key=lambda x: x.accuracy)
    improvement = best_result.accuracy - starting_accuracy

    # Summary
    await flyte.report.log.aio(
        f"""
    <div class="summary-card">
        <h3>🏆 Summary</h3>
        <p><strong>Best Prompt:</strong> {html.escape(best_result.prompt)}</p>
        <p><strong>Best Accuracy:</strong> {best_result.accuracy*100:.2f}%</p>
        <p><strong>Improvement Over Baseline:</strong> {improvement*100:.2f}%</p>
    </div>
    """,
        do_flush=True,
    )

    return best_result.prompt, best_result.accuracy

# {{/docs-fragment prompt_optimizer}}

async def _log_prompt_row(prompt: str, accuracy: float):
    """Helper to log a single prompt/accuracy row to Flyte report."""
    pct = accuracy * 100
    if pct > 80:
        color = "linear-gradient(90deg, #4CAF50, #81C784)"
    elif pct > 60:
        color = "linear-gradient(90deg, #FFC107, #FFD54F)"
    else:
        color = "linear-gradient(90deg, #F44336, #E57373)"

    await flyte.report.log.aio(
        f"""
        <tr>
            <td>{html.escape(prompt)}</td>
            <td>
                {pct:.1f}%
                <div class="accuracy-bar-container">
                    <div class="accuracy-bar" style="width:{pct*1.6:.0f}px; height:8px; border-radius:4px; background:{color};"></div>
                </div>
            </td>
        </tr>
        """,
        do_flush=True,
    )

# {{docs-fragment auto_prompt_engineering}}
@env.task
async def auto_prompt_engineering(
    ground_truth_csv: File | str = "/root/ground_truth.csv",
    db_config: DatabaseConfig = DatabaseConfig(
        csv_zip_path="https://github.com/ppasupat/WikiTableQuestions/releases/download/v1.0.2/WikiTableQuestions-1.0.2-compact.zip",
        search_glob="WikiTableQuestions/csv/200-csv/*.csv",
        concurrency=5,
        model="gpt-4o-mini",
    ),
    target_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1-mini",
        hosted_model_uri=None,
        prompt="""Given an input question, create a syntactically correct {dialect} query to run.

Schema:
{schema}

Question: {query_str}

SQL query to run:
""",
        max_tokens=10000,
    ),
    review_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1",
        hosted_model_uri=None,
        prompt="""Your job is to determine whether the model's response is correct compared to the ground truth taking into account the context of the question.
Both answers were generated by running SQL queries on the same database.

- If the model's response contains all of the ground truth values, and any additional information is harmless (e.g., extra columns or metadata), output "True".
- If it adds incorrect or unrelated rows, or omits required values, output "False".

Question:
{query_str}

Ground Truth:
{answer}

Model Response:
{response}
""",
    ),
    optimizer_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1",
        hosted_model_uri=None,
        temperature=0.7,
        max_tokens=None,
        prompt="""
<EXPLANATION>
I have some prompts along with their corresponding accuracies.
The prompts are arranged in ascending order based on their accuracy, where higher accuracy indicates better quality.
</EXPLANATION>

<PROMPTS>
{prompt_scores_str}
</PROMPTS>

Each prompt was used to translate a natural-language question into a SQL query against a provided database schema.

<EXAMPLE>
<SCHEMA>
artists(id, name)
albums(id, title, artist_id, release_year)
</SCHEMA>
<QUESTION>
How many albums did The Beatles release?
</QUESTION>
<ANSWER>
SELECT COUNT(*) FROM albums a JOIN artists r ON a.artist_id = r.id WHERE r.name = 'The Beatles';
</ANSWER>
</EXAMPLE>

<TASK>
Write a new prompt that will achieve an accuracy as high as possible and that is different from the old ones.
</TASK>

<RULES>
- It is very important that the new prompt is distinct from ALL the old ones!
- Ensure that you analyse the prompts with a high accuracy and reuse the patterns that worked in the past.
- Ensure that you analyse the prompts with a low accuracy and avoid the patterns that didn't work in the past.
- Think out loud before creating the prompt. Describe what has worked in the past and what hasn't. Only then create the new prompt.
- Use all available information like prompt length, formal/informal use of language, etc. for your analysis.
- Be creative, try out different ways of prompting the model. You may even come up with hypothetical scenarios that might improve the accuracy.
- You are generating a system prompt. Always use three placeholders for each prompt: dialect, schema, query_str.
- Write your new prompt in double square brackets. Use only plain text for the prompt text and do not add any markdown (i.e. no hashtags, backticks, quotes, etc).
</RULES>
""",
    ),
    max_iterations: int = 5,
    concurrency: int = 10,
) -> dict[str, Union[str, float]]:
    if isinstance(ground_truth_csv, str) and os.path.isfile(ground_truth_csv):
        ground_truth_csv = await File.from_local(ground_truth_csv)

    df_val, df_test = await data_prep(ground_truth_csv)

    best_prompt, val_accuracy = await prompt_optimizer(
        df_val,
        target_model_config,
        review_model_config,
        optimizer_model_config,
        max_iterations,
        concurrency,
        db_config,
    )

    with flyte.group(name="test_data_evaluation"):
        baseline_test_accuracy = await evaluate_prompt(
            df_test,
            target_model_config,
            review_model_config,
            concurrency,
            db_config,
        )

        target_model_config.prompt = best_prompt
        test_accuracy = await evaluate_prompt(
            df_test,
            target_model_config,
            review_model_config,
            concurrency,
            db_config,
        )

    return {
        "best_prompt": best_prompt,
        "validation_accuracy": val_accuracy,
        "baseline_test_accuracy": baseline_test_accuracy,
        "test_accuracy": test_accuracy,
    }

# {{/docs-fragment auto_prompt_engineering}}

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(auto_prompt_engineering)
    print(run.url)
    run.wait()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/optimizer.py*
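The optimizer's reply is free text: it is asked to reason out loud and then wrap the new prompt in double square brackets, which `prompt_optimizer` extracts with a non-greedy `re.search`. Because the target prompt is a template with `{dialect}`, `{schema}`, and `{query_str}` placeholders (the optimizer's rules require all three), it can be worth rejecting candidates that drop one or contain unbalanced braces. A minimal sketch of such a guard using only the standard library; the check is an addition, not part of the tutorial code, and it takes the last bracketed block on the assumption that the final answer follows the model's reasoning:

```python
import re
import string
from typing import Optional

REQUIRED_PLACEHOLDERS = {"dialect", "schema", "query_str"}


def extract_prompt(response: str) -> Optional[str]:
    """Return the last [[...]] block in the optimizer's reply, or None if there is none."""
    blocks = re.findall(r"\[\[(.*?)\]\]", response, re.DOTALL)
    return blocks[-1].strip() if blocks else None


def template_fields(template: str) -> set[str]:
    """Field names used by a str.format template, e.g. '{schema}' -> 'schema'."""
    return {field for _, field, _, _ in string.Formatter().parse(template) if field}


def accept_candidate(response: str) -> Optional[str]:
    """Extract a candidate prompt and keep it only if it is a usable template."""
    prompt = extract_prompt(response)
    if prompt is None:
        return None
    try:
        fields = template_fields(prompt)
    except ValueError:  # unbalanced braces would also break str.format later
        return None
    return prompt if REQUIRED_PLACEHOLDERS <= fields else None
```

A rejected candidate can be handled like the tutorial's "No new prompt found" case, so a malformed template never reaches evaluation.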

### Evaluation pipeline

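We evaluate each prompt variant against the ground-truth QA dataset, which `data_prep` shuffles and splits 50/50 into validation and test sets. For each question, the target model retrieves the relevant tables, writes SQL, and executes it; a reviewer model then judges the result against the ground truth, and `evaluate_prompt` streams the running accuracy to a live Flyte report.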

```
# {{docs-fragment evaluate_prompt}}
@env.task(report=True)
async def evaluate_prompt(
    df: pd.DataFrame,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
    concurrency: int,
    db_config: DatabaseConfig,
) -> float:
    semaphore = asyncio.Semaphore(concurrency)
    counter = {"correct": 0, "processed": 0}
    counter_lock = asyncio.Lock()

    # Write initial HTML structure
    await flyte.report.log.aio(
        CSS
        + """
        <script>
            function updateAccuracy(percent) {
                const bar = document.getElementById('acc-bar');
                const label = document.getElementById('acc-label');
                bar.setAttribute('width', percent * 3);
                label.textContent = `Accuracy: ${percent.toFixed(1)}%`;
            }
        </script>

        <h2 style="margin-top:0;">Model Evaluation Results</h2>
        <h3>Live Accuracy</h3>
        <svg width="320" height="30" id="accuracy-chart">
            <defs>
                <linearGradient id="acc-gradient" x1="0" x2="1" y1="0" y2="0">
                    <stop offset="0%" stop-color="#66bb6a"/>
                    <stop offset="100%" stop-color="#2e7d32"/>
                </linearGradient>
            </defs>
            <rect width="300" height="20" fill="#ddd" rx="5" ry="5"></rect>
            <rect id="acc-bar" width="0" height="20" fill="url(#acc-gradient)" rx="5" ry="5"></rect>
            <text id="acc-label" x="150" y="15" font-size="12" font-weight="bold" text-anchor="middle" fill="#000">
                Accuracy: 0.0%
            </text>
        </svg>

        <table class="results-table">
            <thead>
                <tr>
                    <th>Question</th>
                    <th>Ground Truth Answer</th>
                    <th>Ground Truth SQL</th>
                    <th>Model Response</th>
                    <th>Model SQL</th>
                    <th>Correct?</th>
                </tr>
            </thead>
            <tbody>
        """,
        do_flush=True,
    )

    db_file, table_infos = await data_ingestion(
        db_config.csv_zip_path,
        db_config.search_glob,
        db_config.concurrency,
        db_config.model,
    )

    vector_index_dir = await index_all_tables(db_file)

    # Launch tasks concurrently
    tasks = [
        run_grouped_task(
            i,
            row.Index,
            row.question,
            row.answer,
            row.sql,
            semaphore,
            target_model_config,
            review_model_config,
            counter,
            counter_lock,
            db_file,
            table_infos,
            vector_index_dir,
        )
        for i, row in enumerate(df.itertuples(index=True))
    ]
    await asyncio.gather(*tasks)

    # Close table
    await flyte.report.log.aio("</tbody></table>", do_flush=True)

    async with counter_lock:
        return (
            (counter["correct"] / counter["processed"]) if counter["processed"] else 0.0
        )

# {{/docs-fragment evaluate_prompt}}

@dataclass
class PromptResult:
    prompt: str
    accuracy: float

# {{docs-fragment prompt_optimizer}}
@env.task(report=True)
async def prompt_optimizer(
    df_val: pd.DataFrame,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
    optimizer_model_config: ModelConfig,
    max_iterations: int,
    concurrency: int,
    db_config: DatabaseConfig,
) -> tuple[str, float]:
    prompt_accuracies: list[PromptResult] = []

    # Send styling + table header immediately
    await flyte.report.log.aio(
        CSS
        + """
    <h2 style="margin-bottom:6px;">📊 Prompt Accuracy Comparison</h2>
    <table class="results-table">
        <thead>
            <tr>
                <th>Prompt</th>
                <th>Accuracy</th>
            </tr>
        </thead>
    <tbody>
    """,
        do_flush=True,
    )

    # Step 1: Evaluate starting prompt and stream row
    with flyte.group(name="baseline_evaluation"):
        starting_accuracy = await evaluate_prompt(
            df_val,
            target_model_config,
            review_model_config,
            concurrency,
            db_config,
        )
        prompt_accuracies.append(
            PromptResult(prompt=target_model_config.prompt, accuracy=starting_accuracy)
        )

        await _log_prompt_row(target_model_config.prompt, starting_accuracy)

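    # Step 2: Optimize prompts one by one, streaming after each. Every optimizer
    # call counts toward max_iterations, so a reply without a [[...]] prompt
    # cannot make the loop retry forever.
    for step in range(1, max_iterations + 1):
        with flyte.group(name=f"prompt_optimization_step_{step}"):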
            # Prepare prompt scores string for optimizer
            prompt_scores_str = "\n".join(
                f"{result.prompt}: {result.accuracy:.2f}"
                for result in sorted(prompt_accuracies, key=lambda x: x.accuracy)
            )

            optimizer_model_prompt = optimizer_model_config.prompt.format(
                prompt_scores_str=prompt_scores_str
            )
            response = await call_model(
                optimizer_model_config,
                [{"role": "system", "content": optimizer_model_prompt}],
            )
            response = response.strip()

            match = re.search(r"\[\[(.*?)\]\]", response, re.DOTALL)
            if not match:
                print("No new prompt found. Skipping.")
                continue

            new_prompt = match.group(1)
            target_model_config.prompt = new_prompt
            accuracy = await evaluate_prompt(
                df_val,
                target_model_config,
                review_model_config,
                concurrency,
                db_config,
            )
            prompt_accuracies.append(PromptResult(prompt=new_prompt, accuracy=accuracy))

            # Log this new prompt row immediately
            await _log_prompt_row(new_prompt, accuracy)

    # Close table
    await flyte.report.log.aio("</tbody></table>", do_flush=True)

    # Find best
    best_result = max(prompt_accuracies, key=lambda x: x.accuracy)
    improvement = best_result.accuracy - starting_accuracy

    # Summary
    await flyte.report.log.aio(
        f"""
    <div class="summary-card">
        <h3>🏆 Summary</h3>
        <p><strong>Best Prompt:</strong> {html.escape(best_result.prompt)}</p>
        <p><strong>Best Accuracy:</strong> {best_result.accuracy*100:.2f}%</p>
        <p><strong>Improvement Over Baseline:</strong> {improvement*100:.2f}%</p>
    </div>
    """,
        do_flush=True,
    )

    return best_result.prompt, best_result.accuracy

# {{/docs-fragment prompt_optimizer}}

async def _log_prompt_row(prompt: str, accuracy: float):
    """Helper to log a single prompt/accuracy row to Flyte report."""
    pct = accuracy * 100
    if pct > 80:
        color = "linear-gradient(90deg, #4CAF50, #81C784)"
    elif pct > 60:
        color = "linear-gradient(90deg, #FFC107, #FFD54F)"
    else:
        color = "linear-gradient(90deg, #F44336, #E57373)"

    await flyte.report.log.aio(
        f"""
        <tr>
            <td>{html.escape(prompt)}</td>
            <td>
                {pct:.1f}%
                <div class="accuracy-bar-container">
                    <div class="accuracy-bar" style="height:10px; width:{pct*1.6}px; background:{color};"></div>
                </div>
            </td>
        </tr>
        """,
        do_flush=True,
    )

# {{docs-fragment auto_prompt_engineering}}
@env.task
async def auto_prompt_engineering(
    ground_truth_csv: File | str = "/root/ground_truth.csv",
    db_config: DatabaseConfig = DatabaseConfig(
        csv_zip_path="https://github.com/ppasupat/WikiTableQuestions/releases/download/v1.0.2/WikiTableQuestions-1.0.2-compact.zip",
        search_glob="WikiTableQuestions/csv/200-csv/*.csv",
        concurrency=5,
        model="gpt-4o-mini",
    ),
    target_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1-mini",
        hosted_model_uri=None,
        prompt="""Given an input question, create a syntactically correct {dialect} query to run.

Schema:
{schema}

Question: {query_str}

SQL query to run:
""",
        max_tokens=10000,
    ),
    review_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1",
        hosted_model_uri=None,
        prompt="""Your job is to determine whether the model's response is correct compared to the ground truth taking into account the context of the question.
Both answers were generated by running SQL queries on the same database.

- If the model's response contains all of the ground truth values, and any additional information is harmless (e.g., extra columns or metadata), output "True".
- If it adds incorrect or unrelated rows, or omits required values, output "False".

Question:
{query_str}

Ground Truth:
{answer}

Model Response:
{response}
""",
    ),
    optimizer_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1",
        hosted_model_uri=None,
        temperature=0.7,
        max_tokens=None,
        prompt="""
<EXPLANATION>
I have some prompts along with their corresponding accuracies.
The prompts are arranged in ascending order based on their accuracy, where higher accuracy indicates better quality.
</EXPLANATION>

<PROMPTS>
{prompt_scores_str}
</PROMPTS>

Each prompt was used to translate a natural-language question into a SQL query against a provided database schema.

<EXAMPLE>
<SCHEMA>
artists(id, name)
albums(id, title, artist_id, release_year)
</SCHEMA>
<QUESTION>
How many albums did The Beatles release?
</QUESTION>
<ANSWER>
SELECT COUNT(*) FROM albums a JOIN artists r ON a.artist_id = r.id WHERE r.name = 'The Beatles';
</ANSWER>
</EXAMPLE>

<TASK>
Write a new prompt that will achieve an accuracy as high as possible and that is different from the old ones.
</TASK>

<RULES>
- It is very important that the new prompt is distinct from ALL the old ones!
- Ensure that you analyse the prompts with a high accuracy and reuse the patterns that worked in the past.
- Ensure that you analyse the prompts with a low accuracy and avoid the patterns that didn't work in the past.
- Think out loud before creating the prompt. Describe what has worked in the past and what hasn't. Only then create the new prompt.
- Use all available information like prompt length, formal/informal use of language, etc. for your analysis.
- Be creative, try out different ways of prompting the model. You may even come up with hypothetical scenarios that might improve the accuracy.
- You are generating a system prompt. Always use three placeholders for each prompt: dialect, schema, query_str.
- Write your new prompt in double square brackets. Use only plain text for the prompt text and do not add any markdown (i.e. no hashtags, backticks, quotes, etc).
</RULES>
""",
    ),
    max_iterations: int = 5,
    concurrency: int = 10,
) -> dict[str, Union[str, float]]:
    if isinstance(ground_truth_csv, str) and os.path.isfile(ground_truth_csv):
        ground_truth_csv = await File.from_local(ground_truth_csv)

    df_val, df_test = await data_prep(ground_truth_csv)

    best_prompt, val_accuracy = await prompt_optimizer(
        df_val,
        target_model_config,
        review_model_config,
        optimizer_model_config,
        max_iterations,
        concurrency,
        db_config,
    )

    with flyte.group(name="test_data_evaluation"):
        baseline_test_accuracy = await evaluate_prompt(
            df_test,
            target_model_config,
            review_model_config,
            concurrency,
            db_config,
        )

        target_model_config.prompt = best_prompt
        test_accuracy = await evaluate_prompt(
            df_test,
            target_model_config,
            review_model_config,
            concurrency,
            db_config,
        )

    return {
        "best_prompt": best_prompt,
        "validation_accuracy": val_accuracy,
        "baseline_test_accuracy": baseline_test_accuracy,
        "test_accuracy": test_accuracy,
    }

# {{/docs-fragment auto_prompt_engineering}}

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(auto_prompt_engineering)
    print(run.url)
    run.wait()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/optimizer.py*
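Stripped of the Flyte reporting calls, `evaluate_prompt` and `run_grouped_task` follow a standard asyncio pattern: an `asyncio.Semaphore` caps how many rows are in flight, and a dict counter guarded by an `asyncio.Lock` tracks the running accuracy. The sketch below reproduces only that pattern under stated assumptions: `evaluate_row` is a hypothetical stand-in for `generate_and_review` that marks every third row incorrect, and the `in_flight` bookkeeping exists only to show that the cap holds.

```python
import asyncio


async def evaluate_row(i: int) -> bool:
    """Hypothetical stand-in for generate_and_review: waits briefly, then returns a verdict."""
    await asyncio.sleep(0.001 * (i % 4))
    return i % 3 != 0  # pretend every third row is judged incorrect


async def run_row(i, semaphore, counter, counter_lock, in_flight) -> float:
    # At most `concurrency` coroutines get past this line at once.
    async with semaphore:
        in_flight["now"] += 1
        in_flight["peak"] = max(in_flight["peak"], in_flight["now"])
        is_correct = await evaluate_row(i)
        in_flight["now"] -= 1

        # Update the shared tally and read a consistent running accuracy.
        async with counter_lock:
            counter["processed"] += 1
            if is_correct:
                counter["correct"] += 1
            return counter["correct"] / counter["processed"] * 100


async def evaluate(n_rows: int, concurrency: int) -> tuple[float, int]:
    semaphore = asyncio.Semaphore(concurrency)
    counter = {"correct": 0, "processed": 0}
    counter_lock = asyncio.Lock()
    in_flight = {"now": 0, "peak": 0}

    await asyncio.gather(
        *(run_row(i, semaphore, counter, counter_lock, in_flight) for i in range(n_rows))
    )
    accuracy = counter["correct"] / counter["processed"] if counter["processed"] else 0.0
    return accuracy, in_flight["peak"]


accuracy, peak = asyncio.run(evaluate(n_rows=30, concurrency=5))
print(f"accuracy={accuracy:.3f}, peak concurrency={peak}")
```

Because every coroutine shares one event loop, a block with no `await` inside it cannot be interleaved with another coroutine; the lock becomes essential as soon as the guarded section awaits something, and keeps the update correct if it ever does.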

Here's how prompt accuracy evolves across optimization iterations, as shown in the UI report:

![Prompt accuracies](https://raw.githubusercontent.com/unionai/unionai-docs-static/main/images/tutorials/text-to-sql/prompt_accuracies.png)

### Iterative optimization

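An optimizer LLM proposes new prompts by analyzing the prompts tried so far and their validation accuracies: it reuses patterns from high-scoring prompts, avoids patterns from low-scoring ones, and returns its proposal in double square brackets. Each candidate runs through the evaluation loop, and we select the best performer.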
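Reduced to plain Python, the loop looks roughly like the sketch below. `propose` and `evaluate` are hypothetical callables standing in for the optimizer LLM call and for `evaluate_prompt` on the validation split; the `[[...]]` regex and the worst-to-best ordering of past prompts follow the tutorial code. One assumption to note: this sketch counts every optimizer call, including replies with no bracketed prompt, against `max_iterations`, so it always terminates.

```python
import re
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class PromptResult:
    prompt: str
    accuracy: float


def extract_prompt(response: str) -> Optional[str]:
    # The optimizer is told to wrap its proposal in double square brackets.
    match = re.search(r"\[\[(.*?)\]\]", response.strip(), re.DOTALL)
    return match.group(1) if match else None


def optimize(
    baseline_prompt: str,
    propose: Callable[[str], str],  # hypothetical: calls the optimizer LLM
    evaluate: Callable[[str], float],  # hypothetical: scores a prompt on validation data
    max_iterations: int,
) -> tuple[PromptResult, float]:
    history = [PromptResult(baseline_prompt, evaluate(baseline_prompt))]
    for _ in range(max_iterations):
        # Show the optimizer every prompt so far, ordered worst to best.
        scores = "\n".join(
            f"{r.prompt}: {r.accuracy:.2f}"
            for r in sorted(history, key=lambda r: r.accuracy)
        )
        candidate = extract_prompt(propose(scores))
        if candidate is None:
            continue  # malformed reply; the attempt still uses up one iteration
        history.append(PromptResult(candidate, evaluate(candidate)))

    best = max(history, key=lambda r: r.accuracy)
    return best, best.accuracy - history[0].accuracy


# Usage with canned replies and scores in place of real model calls.
fake_scores = {"base": 0.50, "v1": 0.55, "v2": 0.62}
replies = iter(["Thinking... [[v1]]", "no brackets here", "Better: [[v2]]"])
best, gain = optimize("base", lambda _: next(replies), fake_scores.__getitem__, max_iterations=3)
print(best.prompt, round(gain, 2))  # v2 0.12
```

The returned gain is an absolute difference in accuracy (0.12 means 12 percentage points), which is also how the summary card's "Improvement Over Baseline" figure is computed.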

On paper, this creates a continuous improvement cycle: baseline → new variants → measured gains.

## Run it

To create the QA dataset:

```
python create_qa_dataset.py
```

To run the prompt optimization loop:

```
python optimizer.py
```

## What we observed

Prompt optimization didn't consistently lift SQL accuracy in this workflow. Accuracy plateaued near the baseline. But the process surfaced valuable lessons about what matters when building LLM-powered systems on real infrastructure.

- **Schema clarity matters**: CSV ingestion produced tables with overlapping names, creating ambiguity. This showed how schema design and metadata hygiene directly affect downstream evaluation.
- **Ground truth needs trust**: Because the dataset came from LLM outputs, noise remained even after filtering. Human review proved essential. Golden datasets need deliberate curation, not just automation.
- **Optimization needs context**: The optimizer couldn't “see” which examples failed, limiting its ability to improve. Feeding failures directly risks overfitting. A structured way to capture and reuse evaluation signals is the right long-term path.
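One hypothetical way to capture evaluation signals is to keep structured failure records and pass the optimizer only a small, reproducible sample instead of every failure. `FailureRecord`, `summarize_failures`, and the three-example cap are illustrative choices, not part of the tutorial code. Capping the sample is one possible guard against overfitting, not a proven fix.

```python
import random
from dataclasses import dataclass


@dataclass
class FailureRecord:
    """Hypothetical record of one failed evaluation example."""
    question: str
    expected: str
    model_response: str


def summarize_failures(failures: list[FailureRecord], max_examples: int = 3, seed: int = 0) -> str:
    """Return a short text summary of a bounded sample of failed examples."""
    rng = random.Random(seed)  # fixed seed: the same failures always yield the same sample
    sample = rng.sample(failures, k=min(max_examples, len(failures)))
    lines = [f"{len(failures)} examples failed. A sample:"]
    for record in sample:
        lines.append(
            f"- Question: {record.question}\n"
            f"  Expected: {record.expected}\n"
            f"  Got: {record.model_response}"
        )
    return "\n".join(lines)


failures = [FailureRecord(f"Q{i}", f"A{i}", f"wrong{i}") for i in range(10)]
print(summarize_failures(failures))
```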

Sometimes prompt tweaks alone can lift accuracy, but other times the real bottleneck lives in the data, the schema, or the evaluation loop. The lesson isn't "prompt optimization doesn't work", but that its impact depends on the system around it. Accuracy improves most reliably when prompts evolve alongside clean data, trusted evaluation, and observable feedback loops.

## The bigger lesson

Evaluation and optimization aren't one-off experiments; they're continuous processes. What makes them sustainable isn't a clever prompt; it's the platform around it.

Systems succeed when they:

- **Observe** failures with clarity: track exactly what failed and why.
- **Remain durable** across iterations: run pipelines that are stable, reproducible, and comparable over time.

That's where Flyte 2 comes in. Prompt optimization is one lever, but it becomes powerful only when combined with:

- Clean, human-validated evaluation datasets.
- Systematic reporting and feedback loops.

**The real takeaway: improving LLM pipelines isn't about chasing the perfect prompt. It's about designing workflows with observability and durability at the core, so that every experiment compounds into long-term progress.**

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/context-engineering/auto_prompt_engineering ===

# Automatic prompt engineering

> [!NOTE]
> Code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/auto_prompt_engineering).

When building with LLMs and agents, the first prompt almost never works. We usually need several iterations before results are useful. Doing this manually is slow, inconsistent, and hard to reproduce.

Flyte turns prompt engineering into a systematic process. With Flyte we can:

- Generate candidate prompts automatically.
- Run evaluations in parallel.
- Track results in real time with built-in observability.
- Recover from failures without losing progress.
- Trace the lineage of every experiment for reproducibility.
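To run evaluations in parallel without overwhelming a model API, you need to bound concurrency. The `evaluate_prompt` task in this tutorial does this with an `asyncio.Semaphore`. The sketch below shows the same pattern in isolation, with two substitutions:

- A hypothetical `fake_model` coroutine stands in for the LLM call.
- Exact-match grading stands in for the review model.

```python
import asyncio


async def fake_model(question: str) -> str:
    """Hypothetical stand-in for an LLM call: returns the question upper-cased."""
    await asyncio.sleep(0.01)  # simulate network latency
    return question.upper()


async def evaluate(rows: list[tuple[str, str]], concurrency: int) -> tuple[float, int]:
    """Grade every row with at most `concurrency` model calls in flight.

    Returns (accuracy, peak number of simultaneous calls).
    """
    semaphore = asyncio.Semaphore(concurrency)
    in_flight = 0
    peak = 0

    async def grade(question: str, answer: str) -> bool:
        nonlocal in_flight, peak
        async with semaphore:
            in_flight += 1
            peak = max(peak, in_flight)
            try:
                response = await fake_model(question)
            finally:
                in_flight -= 1
        # Exact-match grading stands in for the tutorial's review model.
        return response == answer

    results = await asyncio.gather(*(grade(q, a) for q, a in rows))
    accuracy = sum(results) / len(results) if results else 0.0
    return accuracy, peak


rows = [("a", "A"), ("b", "B"), ("c", "x"), ("d", "D")]
accuracy, peak = asyncio.run(evaluate(rows, concurrency=2))
print(f"accuracy={accuracy:.2f}, peak concurrency={peak}")
```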

And we're not limited to prompts. Just like [hyperparameter optimization](../../model-training/hpo/_index) in ML, we can tune model temperature, retrieval strategies, tool usage, and more. Over time, this can grow into full agentic evaluations that track not only prompts but also how agents behave, make decisions, and interact with their environment.
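Tuning settings other than the prompt follows the same evaluate-and-compare pattern as a hyperparameter search. The sketch below is an exhaustive grid search over temperature and retrieval strategy. It assumes a hypothetical `score_config` function; a real version would run the evaluation set with the given settings.

```python
import itertools


def score_config(temperature: float, retrieval: str) -> float:
    """Hypothetical scorer; a real one would evaluate the golden dataset."""
    bonus = {"bm25": 0.05, "dense": 0.10}[retrieval]
    return round(0.7 - abs(temperature - 0.3) + bonus, 4)


def grid_search(temperatures: list[float], retrievals: list[str]) -> tuple[dict, float]:
    """Score every (temperature, retrieval) pair and keep the best one."""
    best_config, best_score = None, float("-inf")
    for temperature, retrieval in itertools.product(temperatures, retrievals):
        score = score_config(temperature, retrieval)
        if score > best_score:
            best_config = {"temperature": temperature, "retrieval": retrieval}
            best_score = score
    return best_config, best_score


config, score = grid_search([0.0, 0.3, 0.7], ["bm25", "dense"])
print(config, score)
```

Each combination is independent of the others, so in a Flyte pipeline these evaluations could run in parallel, much as `evaluate_prompt` fans out over rows.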

In this tutorial, we'll build an automated prompt engineering pipeline with Flyte, step by step.

## Set up the environment

First, let's configure our task environment.

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "pandas==2.3.1",
#    "pyarrow==21.0.0",
#    "litellm==1.75.0",
# ]
# main = "auto_prompt_engineering"
# params = ""
# ///

# {{docs-fragment env}}
import asyncio
import html
import os
import re
from dataclasses import dataclass
from typing import Optional, Union

import flyte
import flyte.report
import pandas as pd
from flyte.io import File

env = flyte.TaskEnvironment(
    name="auto-prompt-engineering",
    image=flyte.Image.from_uv_script(
        __file__, name="auto-prompt-engineering", pre=True
    ),
    secrets=[flyte.Secret(key="openai_api_key", as_env_var="OPENAI_API_KEY")],
    resources=flyte.Resources(cpu=1),
)

CSS = """
<style>
    body {
        font-family: 'Segoe UI', Roboto, Arial, sans-serif;
    }
    .results-table {
        border-collapse: collapse;
        width: 100%;
        box-shadow: 0 2px 5px rgba(0,0,0,0.1);
        font-size: 14px;
    }
    .results-table th {
        background: linear-gradient(135deg, #4CAF50, #2E7D32);
        color: white;
        padding: 10px;
        text-align: left;
    }
    .results-table td {
        border: 1px solid #ddd;
        padding: 8px;
        vertical-align: top;
    }
    .results-table tr:nth-child(even) {background-color: #f9f9f9;}
    .results-table tr:hover {background-color: #f1f1f1;}
    .correct {color: #2E7D32; font-weight: bold;}
    .incorrect {color: #C62828; font-weight: bold;}
    .summary-card {
        background: #f9fbfd;
        padding: 14px 18px;
        border-radius: 8px;
        box-shadow: 0 1px 4px rgba(0,0,0,0.05);
        max-width: 800px;
        margin-top: 12px;
    }
    .summary-card h3 {
        margin-top: 0;
        color: #1e88e5;
        font-size: 16px;
    }
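    /* Accuracy bars rendered by _log_prompt_row */
    .accuracy-bar-container {
        width: 160px;
        height: 8px;
        margin-top: 4px;
        background: #eee;
        border-radius: 4px;
    }
    .accuracy-bar {
        height: 100%;
        border-radius: 4px;
    }
</style>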
"""

# {{/docs-fragment env}}

# {{docs-fragment data_prep}}
@env.task
async def data_prep(csv_file: File | str) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Load Q&A data from a public Google Sheet CSV export URL and split into train/test DataFrames.
    The sheet should have columns: 'input' and 'target'.
    """
    df = pd.read_csv(
        await csv_file.download() if isinstance(csv_file, File) else csv_file
    )

    if "input" not in df.columns or "target" not in df.columns:
        raise ValueError("Sheet must contain 'input' and 'target' columns.")

    # Shuffle rows
    df = df.sample(frac=1, random_state=1234).reset_index(drop=True)

    # Train/Test split
    df_train = df.iloc[:150].rename(columns={"input": "question", "target": "answer"})
    df_test = df.iloc[150:250].rename(columns={"input": "question", "target": "answer"})

    return df_train, df_test

# {{/docs-fragment data_prep}}

# {{docs-fragment model_config}}
@dataclass
class ModelConfig:
    model_name: str
    hosted_model_uri: Optional[str] = None
    temperature: float = 0.0
    max_tokens: Optional[int] = 1000
    timeout: int = 600
    prompt: str = ""

# {{/docs-fragment model_config}}

# {{docs-fragment call_model}}
@flyte.trace
async def call_model(
    model_config: ModelConfig,
    messages: list[dict[str, str]],
) -> str:
    from litellm import acompletion

    response = await acompletion(
        model=model_config.model_name,
        api_base=model_config.hosted_model_uri,
        messages=messages,
        temperature=model_config.temperature,
        timeout=model_config.timeout,
        max_tokens=model_config.max_tokens,
    )
    return response.choices[0].message["content"]

# {{/docs-fragment call_model}}

# {{docs-fragment generate_and_review}}
async def generate_and_review(
    index: int,
    question: str,
    answer: str,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
) -> dict:
    # Generate response from target model
    response = await call_model(
        target_model_config,
        [
            {"role": "system", "content": target_model_config.prompt},
            {"role": "user", "content": question},
        ],
    )

    # Format review prompt with response + answer
    review_messages = [
        {
            "role": "system",
            "content": review_model_config.prompt.format(
                response=response,
                answer=answer,
            ),
        }
    ]
    verdict = await call_model(review_model_config, review_messages)

    # Normalize verdict
    verdict_clean = verdict.strip().lower()
    if verdict_clean not in {"true", "false"}:
        verdict_clean = "not sure"

    return {
        "index": index,
        "model_response": response,
        "is_correct": verdict_clean == "true",
    }

# {{/docs-fragment generate_and_review}}

async def run_grouped_task(
    i,
    index,
    question,
    answer,
    semaphore,
    target_model_config,
    review_model_config,
    counter,
    counter_lock,
):
    async with semaphore:
        with flyte.group(name=f"row-{i}"):
            result = await generate_and_review(
                index,
                question,
                answer,
                target_model_config,
                review_model_config,
            )

            async with counter_lock:
                # Update counters
                counter["processed"] += 1
                if result["is_correct"]:
                    counter["correct"] += 1
                    correct_html = "<span class='correct'>✔ Yes</span>"
                else:
                    correct_html = "<span class='incorrect'>✘ No</span>"

                # Calculate accuracy
                accuracy_pct = (counter["correct"] / counter["processed"]) * 100

            # Update chart
            await flyte.report.log.aio(
                f"<script>updateAccuracy({accuracy_pct});</script>",
                do_flush=True,
            )

            # Add row to table
            await flyte.report.log.aio(
                f"""
                <tr>
                    <td>{html.escape(question)}</td>
                    <td>{html.escape(answer)}</td>
                    <td>{html.escape(result['model_response'])}</td>
                    <td>{correct_html}</td>
                </tr>
                """,
                do_flush=True,
            )

            return result

# {{docs-fragment evaluate_prompt}}
@env.task(report=True)
async def evaluate_prompt(
    df: pd.DataFrame,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
    concurrency: int,
) -> float:
    semaphore = asyncio.Semaphore(concurrency)
    counter = {"correct": 0, "processed": 0}
    counter_lock = asyncio.Lock()

    # Write initial HTML structure
    await flyte.report.log.aio(
        CSS
        + """
        <script>
            function updateAccuracy(percent) {
                const bar = document.getElementById('acc-bar');
                const label = document.getElementById('acc-label');
                bar.setAttribute('width', percent * 3);
                label.textContent = `Accuracy: ${percent.toFixed(1)}%`;
            }
        </script>

        <h2 style="margin-top:0;">Model Evaluation Results</h2>
        <h3>Live Accuracy</h3>
        <svg width="320" height="30" id="accuracy-chart">
            <defs>
                <linearGradient id="acc-gradient" x1="0" x2="1" y1="0" y2="0">
                    <stop offset="0%" stop-color="#66bb6a"/>
                    <stop offset="100%" stop-color="#2e7d32"/>
                </linearGradient>
            </defs>
            <rect width="300" height="20" fill="#ddd" rx="5" ry="5"></rect>
            <rect id="acc-bar" width="0" height="20" fill="url(#acc-gradient)" rx="5" ry="5"></rect>
            <text id="acc-label" x="150" y="15" font-size="12" font-weight="bold" text-anchor="middle" fill="#000">
                Accuracy: 0.0%
            </text>
        </svg>

        <table class="results-table">
            <thead>
                <tr>
                    <th>Question</th>
                    <th>Answer</th>
                    <th>Model Response</th>
                    <th>Correct?</th>
                </tr>
            </thead>
            <tbody>
        """,
        do_flush=True,
    )

    # Launch tasks concurrently
    tasks = [
        run_grouped_task(
            i,
            row.Index,
            row.question,
            row.answer,
            semaphore,
            target_model_config,
            review_model_config,
            counter,
            counter_lock,
        )
        for i, row in enumerate(df.itertuples(index=True))
    ]
    await asyncio.gather(*tasks)

    # Close table
    await flyte.report.log.aio("</tbody></table>", do_flush=True)

    async with counter_lock:
        return (
            (counter["correct"] / counter["processed"]) if counter["processed"] else 0.0
        )

# {{/docs-fragment evaluate_prompt}}

@dataclass
class PromptResult:
    prompt: str
    accuracy: float

# {{docs-fragment prompt_optimizer}}
@env.task(report=True)
async def prompt_optimizer(
    df_train: pd.DataFrame,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
    optimizer_model_config: ModelConfig,
    max_iterations: int,
    concurrency: int,
) -> tuple[str, float]:
    prompt_accuracies: list[PromptResult] = []

    # Send styling + table header immediately
    await flyte.report.log.aio(
        CSS
        + """
    <h2 style="margin-bottom:6px;">📊 Prompt Accuracy Comparison</h2>
    <table class="results-table">
        <thead>
            <tr>
                <th>Prompt</th>
                <th>Accuracy</th>
            </tr>
        </thead>
    <tbody>
    """,
        do_flush=True,
    )

    # Step 1: Evaluate starting prompt and stream row
    with flyte.group(name="baseline_evaluation"):
        starting_accuracy = await evaluate_prompt(
            df_train,
            target_model_config,
            review_model_config,
            concurrency,
        )
        prompt_accuracies.append(
            PromptResult(prompt=target_model_config.prompt, accuracy=starting_accuracy)
        )

        await _log_prompt_row(target_model_config.prompt, starting_accuracy)

    # Step 2: Optimize prompts one by one, streaming after each.
    # A bounded loop guarantees termination even if the optimizer never
    # returns a prompt wrapped in [[double square brackets]].
    for step in range(1, max_iterations + 1):
        with flyte.group(name=f"prompt_optimization_step_{step}"):
            # Prepare prompt scores string for optimizer (ascending by accuracy)
            prompt_scores_str = "\n".join(
                f"{result.prompt}: {result.accuracy:.2f}"
                for result in sorted(prompt_accuracies, key=lambda x: x.accuracy)
            )

            optimizer_model_prompt = optimizer_model_config.prompt.format(
                prompt_scores_str=prompt_scores_str
            )
            response = await call_model(
                optimizer_model_config,
                [{"role": "system", "content": optimizer_model_prompt}],
            )
            response = response.strip()

            match = re.search(r"\[\[(.*?)\]\]", response, re.DOTALL)
            if not match:
                print(f"Step {step}: no new prompt found. Skipping.")
                continue

            new_prompt = match.group(1).strip()
            target_model_config.prompt = new_prompt
            accuracy = await evaluate_prompt(
                df_train,
                target_model_config,
                review_model_config,
                concurrency,
            )
            prompt_accuracies.append(PromptResult(prompt=new_prompt, accuracy=accuracy))

            # Log this new prompt row immediately
            await _log_prompt_row(new_prompt, accuracy)

    # Close table
    await flyte.report.log.aio("</tbody></table>", do_flush=True)

    # Find best
    best_result = max(prompt_accuracies, key=lambda x: x.accuracy)
    improvement = best_result.accuracy - starting_accuracy

    # Summary
    await flyte.report.log.aio(
        f"""
    <div class="summary-card">
        <h3>🏆 Summary</h3>
        <p><strong>Best Prompt:</strong> {html.escape(best_result.prompt)}</p>
        <p><strong>Best Accuracy:</strong> {best_result.accuracy*100:.2f}%</p>
        <p><strong>Improvement Over Baseline:</strong> {improvement*100:.2f}%</p>
    </div>
    """,
        do_flush=True,
    )

    return best_result.prompt, best_result.accuracy

# {{/docs-fragment prompt_optimizer}}

async def _log_prompt_row(prompt: str, accuracy: float):
    """Helper to log a single prompt/accuracy row to Flyte report."""
    pct = accuracy * 100
    if pct > 80:
        color = "linear-gradient(90deg, #4CAF50, #81C784)"
    elif pct > 60:
        color = "linear-gradient(90deg, #FFC107, #FFD54F)"
    else:
        color = "linear-gradient(90deg, #F44336, #E57373)"

    await flyte.report.log.aio(
        f"""
        <tr>
            <td>{html.escape(prompt)}</td>
            <td>
                {pct:.1f}%
                <div class="accuracy-bar-container">
                    <div class="accuracy-bar" style="width:{pct*1.6}px; background:{color};"></div>
                </div>
            </td>
        </tr>
        """,
        do_flush=True,
    )

# {{docs-fragment auto_prompt_engineering}}
@env.task
async def auto_prompt_engineering(
    csv_file: File | str = "https://dub.sh/geometric-shapes",
    target_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1-mini",
        hosted_model_uri=None,
        prompt="Solve the given problem about geometric shapes. Think step by step.",
        max_tokens=10000,
    ),
    review_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1-mini",
        hosted_model_uri=None,
        prompt="""You are a review model tasked with evaluating the correctness of a response to a geometric shapes problem.
The response may contain detailed steps and explanations, but the final answer is the key point.
Please determine if the final answer provided in the response is correct based on the ground truth answer.
Respond with 'True' if the final answer is correct and 'False' if it is not.
Only respond with 'True' or 'False', nothing else.

Model Response:
{response}

Ground Truth:
{answer}
""",
    ),
    optimizer_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1",
        hosted_model_uri=None,
        temperature=0.7,
        max_tokens=None,
        prompt="""
<EXPLANATION>
I have some prompts along with their corresponding accuracies.
The prompts are arranged in ascending order based on their accuracy, where higher accuracy indicates better quality.
</EXPLANATION>

<PROMPTS>
{prompt_scores_str}
</PROMPTS>

Each prompt was used together with a problem statement around geometric shapes.

<EXAMPLE>
<QUESTION>
This SVG path element <path d="M 55.57,80.69 L 57.38,65.80 M 57.38,65.80 L 48.90,57.46 M 48.90,57.46 L 45.58,47.78 M 45.58,47.78 L 53.25,36.07 L 66.29,48.90 L 78.69,61.09 L 55.57,80.69"/> draws a Options: (A) circle (B) heptagon (C) hexagon (D) kite (E) line (F) octagon (G) pentagon (H) rectangle (I) sector (J) triangle
</QUESTION>
<ANSWER>
(B)
</ANSWER>
</EXAMPLE>

<TASK>
Write a new prompt that will achieve an accuracy as high as possible and that is different from the old ones.
</TASK>

<RULES>
- It is very important that the new prompt is distinct from ALL the old ones!
- Ensure that you analyse the prompts with a high accuracy and reuse the patterns that worked in the past
- Ensure that you analyse the prompts with a low accuracy and avoid the patterns that didn't work in the past
- Think out loud before creating the prompt. Describe what has worked in the past and what hasn't. Only then create the new prompt.
- Use all available information like prompt length, formal/informal use of language, etc. for your analysis.
- Be creative, try out different ways of prompting the model. You may even come up with hypothetical scenarios that might improve the accuracy.
- You are generating system prompts. This means that there should be no placeholders in the prompt, as they cannot be filled at runtime. Instead focus on general instructions that will help the model to solve the task.
- Write your new prompt in double square brackets. Use only plain text for the prompt text and do not add any markdown (i.e. no hashtags, backticks, quotes, etc).
</RULES>
""",
    ),
    max_iterations: int = 3,
    concurrency: int = 10,
) -> dict[str, Union[str, float]]:
    if isinstance(csv_file, str) and os.path.isfile(csv_file):
        csv_file = await File.from_local(csv_file)

    df_train, df_test = await data_prep(csv_file)

    best_prompt, training_accuracy = await prompt_optimizer(
        df_train,
        target_model_config,
        review_model_config,
        optimizer_model_config,
        max_iterations,
        concurrency,
    )

    with flyte.group(name="test_data_evaluation"):
        baseline_test_accuracy = await evaluate_prompt(
            df_test,
            target_model_config,
            review_model_config,
            concurrency,
        )

        target_model_config.prompt = best_prompt
        test_accuracy = await evaluate_prompt(
            df_test,
            target_model_config,
            review_model_config,
            concurrency,
        )

    return {
        "best_prompt": best_prompt,
        "training_accuracy": training_accuracy,
        "baseline_test_accuracy": baseline_test_accuracy,
        "test_accuracy": test_accuracy,
    }

# {{/docs-fragment auto_prompt_engineering}}

# {{docs-fragment main}}
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(auto_prompt_engineering)
    print(run.url)
    run.wait()
# {{/docs-fragment main}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/auto_prompt_engineering/optimizer.py*

All three models in this tutorial are OpenAI models: the `gpt-4.1-mini` target and reviewer, and the `gpt-4.1` optimizer. So we need an OpenAI API key. Add it as a Flyte secret:

```
flyte create secret openai_api_key <YOUR_OPENAI_API_KEY>
```
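The task environment exposes this secret as the `OPENAI_API_KEY` environment variable (`as_env_var="OPENAI_API_KEY"`), and LiteLLM reads that variable when it calls OpenAI models. If you want a task to fail fast with a clear message when the secret is missing, you can run a small check like this before any model call. The helper `require_api_key` is hypothetical and not part of the tutorial code.

```python
import os


def require_api_key(var_name: str = "OPENAI_API_KEY") -> str:
    """Return the API key from the environment, or raise a descriptive error."""
    value = os.environ.get(var_name, "").strip()
    if not value:
        raise RuntimeError(
            f"{var_name} is not set. Create the secret with "
            "`flyte create secret openai_api_key <YOUR_OPENAI_API_KEY>` "
            "and make sure the task environment declares it."
        )
    return value
```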

We also define CSS styles for live HTML reports that track prompt optimization in real time:

![Results](https://raw.githubusercontent.com/unionai/unionai-docs-static/main/gifs/tutorials/prompt_engineering/results.gif)
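The reports are assembled from raw HTML strings, so escape any text that comes from the dataset or a model before logging it. The questions in this dataset contain SVG markup such as `<path d=...>`, which the browser would otherwise interpret as HTML. The sketch below builds one results row with the standard library's `html.escape`. The `results_row` helper is hypothetical; the CSS class names match the tutorial's `CSS` block.

```python
import html


def results_row(question: str, answer: str, model_response: str, is_correct: bool) -> str:
    """Build one <tr> for the results table, escaping every untrusted field."""
    verdict = (
        "<span class='correct'>✔ Yes</span>"
        if is_correct
        else "<span class='incorrect'>✘ No</span>"
    )
    cells = [html.escape(question), html.escape(answer), html.escape(model_response)]
    return "<tr>" + "".join(f"<td>{cell}</td>" for cell in cells) + f"<td>{verdict}</td></tr>"


row = results_row('This SVG path element <path d="M 1,2 L 3,4"/> draws a', "(E)", "(E) line", True)
print(row)
```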

## Prepare the evaluation dataset

Next, we define our golden dataset: a set of inputs with known correct outputs. We score each candidate prompt by how many of these inputs the target model answers correctly.

For this tutorial, we use a small geometric shapes dataset. To keep the pipeline portable, the data prep task accepts a CSV file, either as a Flyte `File` or as a string URL, and splits it into train and test subsets. A local path passed to the main workflow is uploaded as a `File` first.
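The split is deterministic. `data_prep` shuffles all rows with a fixed seed (`random_state=1234`), then takes the first 150 rows for training and the next 100 for testing, so every run compares prompts on the same examples. The same idea looks like this in standard-library Python. This sketch uses the `random` module, so it does not reproduce the row order that pandas' `sample` produces.

```python
import random


def shuffle_and_split(rows: list, seed: int = 1234, n_train: int = 150, n_test: int = 100):
    """Deterministically shuffle rows, then slice off train and test subsets."""
    shuffled = list(rows)  # copy so the caller's list is untouched
    random.Random(seed).shuffle(shuffled)
    train = shuffled[:n_train]
    test = shuffled[n_train : n_train + n_test]
    return train, test


rows = [{"question": f"q{i}", "answer": f"a{i}"} for i in range(250)]
train, test = shuffle_and_split(rows)
print(len(train), len(test))  # 150 100
```

Smaller datasets behave the same way as `df.iloc[150:250]`:

- With fewer than 250 rows, the test split simply shrinks.
- With 150 rows or fewer, the test split is empty, and `evaluate_prompt` reports an accuracy of 0.0.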

If you already have prompts and outputs in Google Sheets, simply export them as CSV with two columns: `input` and `target`.
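For a Google Sheet shared publicly ("anyone with the link"), the CSV export URL usually takes the form `https://docs.google.com/spreadsheets/d/<SHEET_ID>/export?format=csv&gid=<TAB_ID>`. This URL pattern is Google's convention, not a Flyte feature. The sketch below builds that URL and checks a CSV header for the two required columns, using only the standard library. The sheet ID and helper names are hypothetical placeholders.

```python
import csv
import io
from urllib.parse import urlencode


def sheet_csv_url(sheet_id: str, gid: int = 0) -> str:
    """Build a CSV export URL for one tab of a publicly shared Google Sheet."""
    query = urlencode({"format": "csv", "gid": gid})
    return f"https://docs.google.com/spreadsheets/d/{sheet_id}/export?{query}"


def check_columns(csv_text: str, required: tuple[str, ...] = ("input", "target")) -> list[str]:
    """Return the header row, raising if a required column is missing (like data_prep)."""
    header = next(csv.reader(io.StringIO(csv_text)), [])
    missing = [name for name in required if name not in header]
    if missing:
        raise ValueError(f"CSV is missing required columns: {missing}")
    return header


url = sheet_csv_url("EXAMPLE_SHEET_ID")  # hypothetical sheet id
print(url)
print(check_columns("input,target\nWhat shape has three sides?,(J)\n"))
```

You can pass the resulting URL as `csv_file`. `data_prep` hands string inputs directly to `pd.read_csv`, which accepts URLs.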

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "pandas==2.3.1",
#    "pyarrow==21.0.0",
#    "litellm==1.75.0",
# ]
# main = "auto_prompt_engineering"
# params = ""
# ///

# {{docs-fragment env}}
import asyncio
import html
import os
import re
from dataclasses import dataclass
from typing import Optional, Union

import flyte
import flyte.report
import pandas as pd
from flyte.io import File

env = flyte.TaskEnvironment(
    name="auto-prompt-engineering",
    image=flyte.Image.from_uv_script(
        __file__, name="auto-prompt-engineering", pre=True
    ),
    secrets=[flyte.Secret(key="openai_api_key", as_env_var="OPENAI_API_KEY")],
    resources=flyte.Resources(cpu=1),
)

CSS = """
<style>
    body {
        font-family: 'Segoe UI', Roboto, Arial, sans-serif;
    }
    .results-table {
        border-collapse: collapse;
        width: 100%;
        box-shadow: 0 2px 5px rgba(0,0,0,0.1);
        font-size: 14px;
    }
    .results-table th {
        background: linear-gradient(135deg, #4CAF50, #2E7D32);
        color: white;
        padding: 10px;
        text-align: left;
    }
    .results-table td {
        border: 1px solid #ddd;
        padding: 8px;
        vertical-align: top;
    }
    .results-table tr:nth-child(even) {background-color: #f9f9f9;}
    .results-table tr:hover {background-color: #f1f1f1;}
    .correct {color: #2E7D32; font-weight: bold;}
    .incorrect {color: #C62828; font-weight: bold;}
    .summary-card {
        background: #f9fbfd;
        padding: 14px 18px;
        border-radius: 8px;
        box-shadow: 0 1px 4px rgba(0,0,0,0.05);
        max-width: 800px;
        margin-top: 12px;
    }
    .summary-card h3 {
        margin-top: 0;
        color: #1e88e5;
        font-size: 16px;
    }
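    /* Accuracy bars rendered by _log_prompt_row */
    .accuracy-bar-container {
        width: 160px;
        height: 8px;
        margin-top: 4px;
        background: #eee;
        border-radius: 4px;
    }
    .accuracy-bar {
        height: 100%;
        border-radius: 4px;
    }
</style>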
"""

# {{/docs-fragment env}}

# {{docs-fragment data_prep}}
@env.task
async def data_prep(csv_file: File | str) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Load Q&A data from a public Google Sheet CSV export URL and split into train/test DataFrames.
    The sheet should have columns: 'input' and 'target'.
    """
    df = pd.read_csv(
        await csv_file.download() if isinstance(csv_file, File) else csv_file
    )

    if "input" not in df.columns or "target" not in df.columns:
        raise ValueError("Sheet must contain 'input' and 'target' columns.")

    # Shuffle rows
    df = df.sample(frac=1, random_state=1234).reset_index(drop=True)

    # Train/Test split
    df_train = df.iloc[:150].rename(columns={"input": "question", "target": "answer"})
    df_test = df.iloc[150:250].rename(columns={"input": "question", "target": "answer"})

    return df_train, df_test

# {{/docs-fragment data_prep}}

# {{docs-fragment model_config}}
@dataclass
class ModelConfig:
    model_name: str
    hosted_model_uri: Optional[str] = None
    temperature: float = 0.0
    max_tokens: Optional[int] = 1000
    timeout: int = 600
    prompt: str = ""

# {{/docs-fragment model_config}}

# {{docs-fragment call_model}}
@flyte.trace
async def call_model(
    model_config: ModelConfig,
    messages: list[dict[str, str]],
) -> str:
    from litellm import acompletion

    response = await acompletion(
        model=model_config.model_name,
        api_base=model_config.hosted_model_uri,
        messages=messages,
        temperature=model_config.temperature,
        timeout=model_config.timeout,
        max_tokens=model_config.max_tokens,
    )
    return response.choices[0].message["content"]

# {{/docs-fragment call_model}}

# {{docs-fragment generate_and_review}}
async def generate_and_review(
    index: int,
    question: str,
    answer: str,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
) -> dict:
    # Generate response from target model
    response = await call_model(
        target_model_config,
        [
            {"role": "system", "content": target_model_config.prompt},
            {"role": "user", "content": question},
        ],
    )

    # Format review prompt with response + answer
    review_messages = [
        {
            "role": "system",
            "content": review_model_config.prompt.format(
                response=response,
                answer=answer,
            ),
        }
    ]
    verdict = await call_model(review_model_config, review_messages)

    # Normalize verdict
    verdict_clean = verdict.strip().lower()
    if verdict_clean not in {"true", "false"}:
        verdict_clean = "not sure"

    return {
        "index": index,
        "model_response": response,
        "is_correct": verdict_clean == "true",
    }

# {{/docs-fragment generate_and_review}}

async def run_grouped_task(
    i,
    index,
    question,
    answer,
    semaphore,
    target_model_config,
    review_model_config,
    counter,
    counter_lock,
):
    async with semaphore:
        with flyte.group(name=f"row-{i}"):
            result = await generate_and_review(
                index,
                question,
                answer,
                target_model_config,
                review_model_config,
            )

            async with counter_lock:
                # Update counters
                counter["processed"] += 1
                if result["is_correct"]:
                    counter["correct"] += 1
                    correct_html = "<span class='correct'>✔ Yes</span>"
                else:
                    correct_html = "<span class='incorrect'>✘ No</span>"

                # Calculate accuracy
                accuracy_pct = (counter["correct"] / counter["processed"]) * 100

            # Update chart
            await flyte.report.log.aio(
                f"<script>updateAccuracy({accuracy_pct});</script>",
                do_flush=True,
            )

            # Add row to table
            await flyte.report.log.aio(
                f"""
                <tr>
                    <td>{html.escape(question)}</td>
                    <td>{html.escape(answer)}</td>
                    <td>{html.escape(result['model_response'])}</td>
                    <td>{correct_html}</td>
                </tr>
                """,
                do_flush=True,
            )

            return result

# {{docs-fragment evaluate_prompt}}
@env.task(report=True)
async def evaluate_prompt(
    df: pd.DataFrame,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
    concurrency: int,
) -> float:
    semaphore = asyncio.Semaphore(concurrency)
    counter = {"correct": 0, "processed": 0}
    counter_lock = asyncio.Lock()

    # Write initial HTML structure
    await flyte.report.log.aio(
        CSS
        + """
        <script>
            function updateAccuracy(percent) {
                const bar = document.getElementById('acc-bar');
                const label = document.getElementById('acc-label');
                bar.setAttribute('width', percent * 3);
                label.textContent = `Accuracy: ${percent.toFixed(1)}%`;
            }
        </script>

        <h2 style="margin-top:0;">Model Evaluation Results</h2>
        <h3>Live Accuracy</h3>
        <svg width="320" height="30" id="accuracy-chart">
            <defs>
                <linearGradient id="acc-gradient" x1="0" x2="1" y1="0" y2="0">
                    <stop offset="0%" stop-color="#66bb6a"/>
                    <stop offset="100%" stop-color="#2e7d32"/>
                </linearGradient>
            </defs>
            <rect width="300" height="20" fill="#ddd" rx="5" ry="5"></rect>
            <rect id="acc-bar" width="0" height="20" fill="url(#acc-gradient)" rx="5" ry="5"></rect>
            <text id="acc-label" x="150" y="15" font-size="12" font-weight="bold" text-anchor="middle" fill="#000">
                Accuracy: 0.0%
            </text>
        </svg>

        <table class="results-table">
            <thead>
                <tr>
                    <th>Question</th>
                    <th>Answer</th>
                    <th>Model Response</th>
                    <th>Correct?</th>
                </tr>
            </thead>
            <tbody>
        """,
        do_flush=True,
    )

    # Launch tasks concurrently
    tasks = [
        run_grouped_task(
            i,
            row.Index,
            row.question,
            row.answer,
            semaphore,
            target_model_config,
            review_model_config,
            counter,
            counter_lock,
        )
        for i, row in enumerate(df.itertuples(index=True))
    ]
    await asyncio.gather(*tasks)

    # Close table
    await flyte.report.log.aio("</tbody></table>", do_flush=True)

    async with counter_lock:
        return (
            (counter["correct"] / counter["processed"]) if counter["processed"] else 0.0
        )

# {{/docs-fragment evaluate_prompt}}

@dataclass
class PromptResult:
    prompt: str
    accuracy: float

# {{docs-fragment prompt_optimizer}}
@env.task(report=True)
async def prompt_optimizer(
    df_train: pd.DataFrame,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
    optimizer_model_config: ModelConfig,
    max_iterations: int,
    concurrency: int,
) -> tuple[str, float]:
    prompt_accuracies: list[PromptResult] = []

    # Send styling + table header immediately
    await flyte.report.log.aio(
        CSS
        + """
    <h2 style="margin-bottom:6px;">📊 Prompt Accuracy Comparison</h2>
    <table class="results-table">
        <thead>
            <tr>
                <th>Prompt</th>
                <th>Accuracy</th>
            </tr>
        </thead>
    <tbody>
    """,
        do_flush=True,
    )

    # Step 1: Evaluate starting prompt and stream row
    with flyte.group(name="baseline_evaluation"):
        starting_accuracy = await evaluate_prompt(
            df_train,
            target_model_config,
            review_model_config,
            concurrency,
        )
        prompt_accuracies.append(
            PromptResult(prompt=target_model_config.prompt, accuracy=starting_accuracy)
        )

        await _log_prompt_row(target_model_config.prompt, starting_accuracy)

    # Step 2: Optimize prompts one by one, streaming after each.
    # A bounded loop guarantees termination even if the optimizer never
    # returns a prompt wrapped in [[double square brackets]].
    for step in range(1, max_iterations + 1):
        with flyte.group(name=f"prompt_optimization_step_{step}"):
            # Prepare prompt scores string for optimizer (ascending by accuracy)
            prompt_scores_str = "\n".join(
                f"{result.prompt}: {result.accuracy:.2f}"
                for result in sorted(prompt_accuracies, key=lambda x: x.accuracy)
            )

            optimizer_model_prompt = optimizer_model_config.prompt.format(
                prompt_scores_str=prompt_scores_str
            )
            response = await call_model(
                optimizer_model_config,
                [{"role": "system", "content": optimizer_model_prompt}],
            )
            response = response.strip()

            match = re.search(r"\[\[(.*?)\]\]", response, re.DOTALL)
            if not match:
                print(f"Step {step}: no new prompt found. Skipping.")
                continue

            new_prompt = match.group(1).strip()
            target_model_config.prompt = new_prompt
            accuracy = await evaluate_prompt(
                df_train,
                target_model_config,
                review_model_config,
                concurrency,
            )
            prompt_accuracies.append(PromptResult(prompt=new_prompt, accuracy=accuracy))

            # Log this new prompt row immediately
            await _log_prompt_row(new_prompt, accuracy)

    # Close table
    await flyte.report.log.aio("</tbody></table>", do_flush=True)

    # Find best
    best_result = max(prompt_accuracies, key=lambda x: x.accuracy)
    improvement = best_result.accuracy - starting_accuracy

    # Summary
    await flyte.report.log.aio(
        f"""
    <div class="summary-card">
        <h3>🏆 Summary</h3>
        <p><strong>Best Prompt:</strong> {html.escape(best_result.prompt)}</p>
        <p><strong>Best Accuracy:</strong> {best_result.accuracy*100:.2f}%</p>
        <p><strong>Improvement Over Baseline:</strong> {improvement*100:.2f}%</p>
    </div>
    """,
        do_flush=True,
    )

    return best_result.prompt, best_result.accuracy

# {{/docs-fragment prompt_optimizer}}

async def _log_prompt_row(prompt: str, accuracy: float):
    """Helper to log a single prompt/accuracy row to Flyte report."""
    pct = accuracy * 100
    if pct > 80:
        color = "linear-gradient(90deg, #4CAF50, #81C784)"
    elif pct > 60:
        color = "linear-gradient(90deg, #FFC107, #FFD54F)"
    else:
        color = "linear-gradient(90deg, #F44336, #E57373)"

    await flyte.report.log.aio(
        f"""
        <tr>
            <td>{html.escape(prompt)}</td>
            <td>
                {pct:.1f}%
                <div class="accuracy-bar-container">
                    <div class="accuracy-bar" style="width:{pct*1.6}px; height:8px; border-radius:4px; background:{color};"></div>
                </div>
            </td>
        </tr>
        """,
        do_flush=True,
    )

# {{docs-fragment auto_prompt_engineering}}
@env.task
async def auto_prompt_engineering(
    csv_file: File | str = "https://dub.sh/geometric-shapes",
    target_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1-mini",
        hosted_model_uri=None,
        prompt="Solve the given problem about geometric shapes. Think step by step.",
        max_tokens=10000,
    ),
    review_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1-mini",
        hosted_model_uri=None,
        prompt="""You are a review model tasked with evaluating the correctness of a response to a geometric-shapes problem.
The response may contain detailed steps and explanations, but the final answer is the key point.
Please determine if the final answer provided in the response is correct based on the ground truth answer.
Respond with 'True' if the final answer is correct and 'False' if it is not.
Only respond with 'True' or 'False', nothing else.

Model Response:
{response}

Ground Truth:
{answer}
""",
    ),
    optimizer_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1",
        hosted_model_uri=None,
        temperature=0.7,
        max_tokens=None,
        prompt="""
<EXPLANATION>
I have some prompts along with their corresponding accuracies.
The prompts are arranged in ascending order based on their accuracy, where higher accuracy indicates better quality.
</EXPLANATION>

<PROMPTS>
{prompt_scores_str}
</PROMPTS>

Each prompt was used together with a problem statement around geometric shapes.

<EXAMPLE>
<QUESTION>
This SVG path element <path d="M 55.57,80.69 L 57.38,65.80 M 57.38,65.80 L 48.90,57.46 M 48.90,57.46 L 45.58,47.78 M 45.58,47.78 L 53.25,36.07 L 66.29,48.90 L 78.69,61.09 L 55.57,80.69"/> draws a Options: (A) circle (B) heptagon (C) hexagon (D) kite (E) line (F) octagon (G) pentagon (H) rectangle (I) sector (J) triangle
</QUESTION>
<ANSWER>
(B)
</ANSWER>
</EXAMPLE>

<TASK>
Write a new prompt that will achieve an accuracy as high as possible and that is different from the old ones.
</TASK>

<RULES>
- It is very important that the new prompt is distinct from ALL the old ones!
- Ensure that you analyse the prompts with a high accuracy and reuse the patterns that worked in the past
- Ensure that you analyse the prompts with a low accuracy and avoid the patterns that didn't work in the past
- Think out loud before creating the prompt. Describe what has worked in the past and what hasn't. Only then create the new prompt.
- Use all available information like prompt length, formal/informal use of language, etc. for your analysis.
- Be creative, try out different ways of prompting the model. You may even come up with hypothetical scenarios that might improve the accuracy.
- You are generating system prompts. This means that there should be no placeholders in the prompt, as they cannot be filled at runtime. Instead focus on general instructions that will help the model to solve the task.
- Write your new prompt in double square brackets. Use only plain text for the prompt text and do not add any markdown (i.e. no hashtags, backticks, quotes, etc.).
</RULES>
""",
    ),
    max_iterations: int = 3,
    concurrency: int = 10,
) -> dict[str, Union[str, float]]:
    if isinstance(csv_file, str) and os.path.isfile(csv_file):
        csv_file = await File.from_local(csv_file)

    df_train, df_test = await data_prep(csv_file)

    best_prompt, training_accuracy = await prompt_optimizer(
        df_train,
        target_model_config,
        review_model_config,
        optimizer_model_config,
        max_iterations,
        concurrency,
    )

    with flyte.group(name="test_data_evaluation"):
        baseline_test_accuracy = await evaluate_prompt(
            df_test,
            target_model_config,
            review_model_config,
            concurrency,
        )

        target_model_config.prompt = best_prompt
        test_accuracy = await evaluate_prompt(
            df_test,
            target_model_config,
            review_model_config,
            concurrency,
        )

    return {
        "best_prompt": best_prompt,
        "training_accuracy": training_accuracy,
        "baseline_test_accuracy": baseline_test_accuracy,
        "test_accuracy": test_accuracy,
    }

# {{/docs-fragment auto_prompt_engineering}}

# {{docs-fragment main}}
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(auto_prompt_engineering)
    print(run.url)
    run.wait()
# {{/docs-fragment main}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/auto_prompt_engineering/optimizer.py*

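This approach works with any dataset that follows the same format: a CSV with an `input` column (the question) and a `target` column (the ground-truth answer). To use your own data, pass its URL or local path as `csv_file`; no extra dependencies are needed. Note that `data_prep` uses the first 150 shuffled rows for training and the next 100 for testing.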
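For reference, here is a minimal, stdlib-only sketch of the checks and split that `data_prep` performs: require `input` and `target` columns, rename them to `question` and `answer`, shuffle with a fixed seed, then take up to 150 training rows and the next 100 as test rows. The helper `load_and_split` and the sample rows are hypothetical, and because it shuffles with `random.Random` rather than pandas' `df.sample`, the row order will differ from the tutorial's.

```python
import csv
import io
import random

# Hypothetical two-row sample in the expected format: `input` holds the
# question, `target` holds the ground-truth answer.
SAMPLE_CSV = """input,target
"Which shape does this SVG path draw? <path d=""M 10,10 L 20,20""/> Options: (A) line (B) circle",(A)
"Which shape has three sides? Options: (A) square (B) triangle",(B)
"""


def load_and_split(csv_text: str, n_train: int = 150, n_test: int = 100, seed: int = 1234):
    """Validate columns, rename them, shuffle, and split into train/test lists."""
    reader = csv.DictReader(io.StringIO(csv_text))
    if not reader.fieldnames or not {"input", "target"} <= set(reader.fieldnames):
        raise ValueError("CSV must contain 'input' and 'target' columns.")
    # Rename to the column names the evaluation code expects.
    rows = [{"question": r["input"], "answer": r["target"]} for r in reader]
    # Deterministic shuffle (not the same order pandas' sample() would produce).
    random.Random(seed).shuffle(rows)
    return rows[:n_train], rows[n_train:n_train + n_test]


train, test = load_and_split(SAMPLE_CSV, n_train=1, n_test=1)
print(len(train), len(test))
```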

## Define models

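We use three models:

- **Target model** → the model whose system prompt we want to optimize.
- **Review model** → judges whether each target-model response matches the ground-truth answer, which is how candidate prompts get scored.
- **Optimizer model** → proposes new candidate prompts based on how earlier prompts scored.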
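To see how the three roles (`target_model_config`, `review_model_config`, and `optimizer_model_config` in the code) fit together, here is a toy, offline sketch of one optimization step. Plain Python functions stand in for the LLMs, and the names `evaluate`, `propose`, `target`, `reviewer`, and `optimizer` are hypothetical. It mirrors two conventions from `prompt_optimizer`: past prompts are listed in ascending order of accuracy, and the new prompt is extracted from `[[...]]`. It runs sequentially and without Flyte, so it omits the tutorial's concurrency, tracing, and reports.

```python
import re
from dataclasses import dataclass
from typing import Callable, Optional

# A "model" here is any function (system_prompt, user_message) -> reply.
# In the tutorial these are LLM calls made through `call_model`.
Model = Callable[[str, str], str]


@dataclass
class PromptResult:
    prompt: str
    accuracy: float


def evaluate(prompt: str, data: list[tuple[str, str]], target: Model, reviewer: Model) -> float:
    """Target answers each question; reviewer replies 'True' or 'False'."""
    correct = 0
    for question, answer in data:
        response = target(prompt, question)
        verdict = reviewer(f"Response: {response}\nGround truth: {answer}", "")
        if verdict.strip().lower() == "true":
            correct += 1
    return correct / len(data) if data else 0.0


def propose(history: list[PromptResult], optimizer: Model) -> Optional[str]:
    """Show past prompts in ascending accuracy; extract the new prompt from [[...]]."""
    scores = "\n".join(
        f"{r.prompt}: {r.accuracy:.2f}" for r in sorted(history, key=lambda r: r.accuracy)
    )
    reply = optimizer(scores, "")
    match = re.search(r"\[\[(.*?)\]\]", reply, re.DOTALL)
    return match.group(1) if match else None


# Hypothetical stand-in models so the step runs offline.
ANSWERS = {"2+2?": "4", "3+3?": "6"}


def target(prompt: str, question: str) -> str:
    # Answers correctly only when the system prompt asks it to show its work.
    return ANSWERS[question] if "show work" in prompt else "?"


def reviewer(system: str, _: str) -> str:
    response_line, truth_line = system.splitlines()
    ok = response_line.removeprefix("Response: ") == truth_line.removeprefix("Ground truth: ")
    return "True" if ok else "False"


def optimizer(scores: str, _: str) -> str:
    return "Low scores came from terse prompts. [[Solve it and show work.]]"


data = list(ANSWERS.items())
history = [PromptResult("Solve it.", evaluate("Solve it.", data, target, reviewer))]

new_prompt = propose(history, optimizer)
if new_prompt is not None:  # a missing [[...]] means this step is skipped
    history.append(PromptResult(new_prompt, evaluate(new_prompt, data, target, reviewer)))

best = max(history, key=lambda r: r.accuracy)
print(best)
```

With these stubs, the baseline prompt scores 0.0 and the proposed prompt scores 1.0, so it becomes `best`. If the optimizer's reply contains no `[[...]]`, `propose` returns `None` and the step is skipped.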

First, we capture all model parameters in a dataclass:

```
@dataclass
class ModelConfig:
    model_name: str
    hosted_model_uri: Optional[str] = None
    temperature: float = 0.0
    max_tokens: Optional[int] = 1000
    timeout: int = 600
    prompt: str = ""

```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/auto_prompt_engineering/optimizer.py*
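Because every model role is described by the same dataclass, variants are cheap to build. The sketch below re-declares `ModelConfig` so it runs standalone and uses the standard library's `dataclasses.replace` to derive new configurations. The candidate prompt, model name, and localhost URL are hypothetical, and the `openai/` prefix is an assumption about LiteLLM's naming for OpenAI-compatible servers; check LiteLLM's provider documentation for your setup.

```python
from dataclasses import dataclass, replace
from typing import Optional


@dataclass
class ModelConfig:
    # Same fields and defaults as the tutorial's ModelConfig.
    model_name: str
    hosted_model_uri: Optional[str] = None
    temperature: float = 0.0
    max_tokens: Optional[int] = 1000
    timeout: int = 600
    prompt: str = ""


base = ModelConfig(
    model_name="gpt-4.1-mini",
    prompt="Solve the given problem about geometric shapes. Think step by step.",
)

# A new candidate prompt; replace() returns a copy, so `base` is untouched.
candidate = replace(
    base, prompt="List the vertices of the SVG path, count them, then pick an option."
)

# The same role pointed at a self-hosted OpenAI-compatible server.
# Hypothetical model name and URL; "openai/" is assumed to be LiteLLM's prefix for such servers.
self_hosted = replace(
    base,
    model_name="openai/my-local-model",
    hosted_model_uri="http://localhost:8000/v1",
)

print(candidate.prompt)
print(self_hosted.model_name, self_hosted.hosted_model_uri)
```

The tutorial itself assigns `target_model_config.prompt = new_prompt` in place; `replace` is an alternative that keeps every candidate's configuration as a separate object.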

Then we define a Flyte `trace` that calls the model through LiteLLM's `acompletion`. Unlike a task, a trace runs within the same runtime as its parent task instead of in a separate container. Because the model is hosted externally, this keeps each call lightweight while still recording it for observability.

```
@flyte.trace
async def call_model(
    model_config: ModelConfig,
    messages: list[dict[str, str]],
) -> str:
    from litellm import acompletion

    response = await acompletion(
        model=model_config.model_name,
        api_base=model_config.hosted_model_uri,
        messages=messages,
        temperature=model_config.temperature,
        timeout=model_config.timeout,
        max_tokens=model_config.max_tokens,
    )
    return response.choices[0].message["content"]

# {{/docs-fragment call_model}}

# {{docs-fragment generate_and_review}}
async def generate_and_review(
    index: int,
    question: str,
    answer: str,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
) -> dict:
    # Generate response from target model
    response = await call_model(
        target_model_config,
        [
            {"role": "system", "content": target_model_config.prompt},
            {"role": "user", "content": question},
        ],
    )

    # Format review prompt with response + answer
    review_messages = [
        {
            "role": "system",
            "content": review_model_config.prompt.format(
                response=response,
                answer=answer,
            ),
        }
    ]
    verdict = await call_model(review_model_config, review_messages)

    # Normalize verdict
    verdict_clean = verdict.strip().lower()
    if verdict_clean not in {"true", "false"}:
        verdict_clean = "not sure"

    return {
        "index": index,
        "model_response": response,
        "is_correct": verdict_clean == "true",
    }

# {{/docs-fragment generate_and_review}}

async def run_grouped_task(
    i,
    index,
    question,
    answer,
    semaphore,
    target_model_config,
    review_model_config,
    counter,
    counter_lock,
):
    async with semaphore:
        with flyte.group(name=f"row-{i}"):
            result = await generate_and_review(
                index,
                question,
                answer,
                target_model_config,
                review_model_config,
            )

            async with counter_lock:
                # Update counters
                counter["processed"] += 1
                if result["is_correct"]:
                    counter["correct"] += 1
                    correct_html = "<span class='correct'>✔ Yes</span>"
                else:
                    correct_html = "<span class='incorrect'>✘ No</span>"

                # Calculate accuracy
                accuracy_pct = (counter["correct"] / counter["processed"]) * 100

            # Update chart
            await flyte.report.log.aio(
                f"<script>updateAccuracy({accuracy_pct});</script>",
                do_flush=True,
            )

            # Add row to table
            await flyte.report.log.aio(
                f"""
                <tr>
                    <td>{html.escape(question)}</td>
                    <td>{html.escape(answer)}</td>
                    <td>{html.escape(result['model_response'])}</td>
                    <td>{correct_html}</td>
                </tr>
                """,
                do_flush=True,
            )

            return result

# {{docs-fragment evaluate_prompt}}
@env.task(report=True)
async def evaluate_prompt(
    df: pd.DataFrame,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
    concurrency: int,
) -> float:
    semaphore = asyncio.Semaphore(concurrency)
    counter = {"correct": 0, "processed": 0}
    counter_lock = asyncio.Lock()

    # Write initial HTML structure
    await flyte.report.log.aio(
        CSS
        + """
        <script>
            function updateAccuracy(percent) {
                const bar = document.getElementById('acc-bar');
                const label = document.getElementById('acc-label');
                bar.setAttribute('width', percent * 3);
                label.textContent = `Accuracy: ${percent.toFixed(1)}%`;
            }
        </script>

        <h2 style="margin-top:0;">Model Evaluation Results</h2>
        <h3>Live Accuracy</h3>
        <svg width="320" height="30" id="accuracy-chart">
            <defs>
                <linearGradient id="acc-gradient" x1="0" x2="1" y1="0" y2="0">
                    <stop offset="0%" stop-color="#66bb6a"/>
                    <stop offset="100%" stop-color="#2e7d32"/>
                </linearGradient>
            </defs>
            <rect width="300" height="20" fill="#ddd" rx="5" ry="5"></rect>
            <rect id="acc-bar" width="0" height="20" fill="url(#acc-gradient)" rx="5" ry="5"></rect>
            <text id="acc-label" x="150" y="15" font-size="12" font-weight="bold" text-anchor="middle" fill="#000">
                Accuracy: 0.0%
            </text>
        </svg>

        <table class="results-table">
            <thead>
                <tr>
                    <th>Question</th>
                    <th>Answer</th>
                    <th>Model Response</th>
                    <th>Correct?</th>
                </tr>
            </thead>
            <tbody>
        """,
        do_flush=True,
    )

    # Launch tasks concurrently
    tasks = [
        run_grouped_task(
            i,
            row.Index,
            row.question,
            row.answer,
            semaphore,
            target_model_config,
            review_model_config,
            counter,
            counter_lock,
        )
        for i, row in enumerate(df.itertuples(index=True))
    ]
    await asyncio.gather(*tasks)

    # Close table
    await flyte.report.log.aio("</tbody></table>", do_flush=True)

    async with counter_lock:
        return (
            (counter["correct"] / counter["processed"]) if counter["processed"] else 0.0
        )

# {{/docs-fragment evaluate_prompt}}

@dataclass
class PromptResult:
    prompt: str
    accuracy: float

# {{docs-fragment prompt_optimizer}}
@env.task(report=True)
async def prompt_optimizer(
    df_train: pd.DataFrame,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
    optimizer_model_config: ModelConfig,
    max_iterations: int,
    concurrency: int,
) -> tuple[str, float]:
    prompt_accuracies: list[PromptResult] = []

    # Send styling + table header immediately
    await flyte.report.log.aio(
        CSS
        + """
    <h2 style="margin-bottom:6px;">📊 Prompt Accuracy Comparison</h2>
    <table class="results-table">
        <thead>
            <tr>
                <th>Prompt</th>
                <th>Accuracy</th>
            </tr>
        </thead>
    <tbody>
    """,
        do_flush=True,
    )

    # Step 1: Evaluate starting prompt and stream row
    with flyte.group(name="baseline_evaluation"):
        starting_accuracy = await evaluate_prompt(
            df_train,
            target_model_config,
            review_model_config,
            concurrency,
        )
        prompt_accuracies.append(
            PromptResult(prompt=target_model_config.prompt, accuracy=starting_accuracy)
        )

        await _log_prompt_row(target_model_config.prompt, starting_accuracy)

    # Step 2: Optimize prompts one by one, streaming after each.
    # Cap the number of optimizer calls so that a response without [[...]]
    # markers is skipped rather than retried forever.
    for step in range(1, max_iterations + 1):
        with flyte.group(name=f"prompt_optimization_step_{step}"):
            # Prepare prompt scores string for optimizer
            prompt_scores_str = "\n".join(
                f"{result.prompt}: {result.accuracy:.2f}"
                for result in sorted(prompt_accuracies, key=lambda x: x.accuracy)
            )

            optimizer_model_prompt = optimizer_model_config.prompt.format(
                prompt_scores_str=prompt_scores_str
            )
            response = await call_model(
                optimizer_model_config,
                [{"role": "system", "content": optimizer_model_prompt}],
            )
            response = response.strip()

            match = re.search(r"\[\[(.*?)\]\]", response, re.DOTALL)
            if not match:
                print("No new prompt found. Skipping.")
                continue

            new_prompt = match.group(1)
            target_model_config.prompt = new_prompt
            accuracy = await evaluate_prompt(
                df_train,
                target_model_config,
                review_model_config,
                concurrency,
            )
            prompt_accuracies.append(PromptResult(prompt=new_prompt, accuracy=accuracy))

            # Log this new prompt row immediately
            await _log_prompt_row(new_prompt, accuracy)

    # Close table
    await flyte.report.log.aio("</tbody></table>", do_flush=True)

    # Find best
    best_result = max(prompt_accuracies, key=lambda x: x.accuracy)
    improvement = best_result.accuracy - starting_accuracy

    # Summary
    await flyte.report.log.aio(
        f"""
    <div class="summary-card">
        <h3>🏆 Summary</h3>
        <p><strong>Best Prompt:</strong> {html.escape(best_result.prompt)}</p>
        <p><strong>Best Accuracy:</strong> {best_result.accuracy*100:.2f}%</p>
        <p><strong>Improvement Over Baseline:</strong> {improvement*100:.2f}%</p>
    </div>
    """,
        do_flush=True,
    )

    return best_result.prompt, best_result.accuracy

# {{/docs-fragment prompt_optimizer}}

async def _log_prompt_row(prompt: str, accuracy: float):
    """Helper to log a single prompt/accuracy row to Flyte report."""
    pct = accuracy * 100
    if pct > 80:
        color = "linear-gradient(90deg, #4CAF50, #81C784)"
    elif pct > 60:
        color = "linear-gradient(90deg, #FFC107, #FFD54F)"
    else:
        color = "linear-gradient(90deg, #F44336, #E57373)"

    await flyte.report.log.aio(
        f"""
        <tr>
            <td>{html.escape(prompt)}</td>
            <td>
                {pct:.1f}%
                <div class="accuracy-bar-container">
                    <div class="accuracy-bar" style="width:{pct*1.6}px; background:{color};"></div>
                </div>
            </td>
        </tr>
        """,
        do_flush=True,
    )

# {{docs-fragment auto_prompt_engineering}}
@env.task
async def auto_prompt_engineering(
    csv_file: File | str = "https://dub.sh/geometric-shapes",
    target_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1-mini",
        hosted_model_uri=None,
        prompt="Solve the given problem about geometric shapes. Think step by step.",
        max_tokens=10000,
    ),
    review_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1-mini",
        hosted_model_uri=None,
        prompt="""You are a review model tasked with evaluating the correctness of a response to a geometric shapes problem.
The response may contain detailed steps and explanations, but the final answer is the key point.
Please determine if the final answer provided in the response is correct based on the ground truth answer.
Respond with 'True' if the final answer is correct and 'False' if it is not.
Only respond with 'True' or 'False', nothing else.

Model Response:
{response}

Ground Truth:
{answer}
""",
    ),
    optimizer_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1",
        hosted_model_uri=None,
        temperature=0.7,
        max_tokens=None,
        prompt="""
<EXPLANATION>
I have some prompts along with their corresponding accuracies.
The prompts are arranged in ascending order based on their accuracy, where higher accuracy indicates better quality.
</EXPLANATION>

<PROMPTS>
{prompt_scores_str}
</PROMPTS>

Each prompt was used together with a problem statement around geometric shapes.

<EXAMPLE>
<QUESTION>
This SVG path element <path d="M 55.57,80.69 L 57.38,65.80 M 57.38,65.80 L 48.90,57.46 M 48.90,57.46 L 45.58,47.78 M 45.58,47.78 L 53.25,36.07 L 66.29,48.90 L 78.69,61.09 L 55.57,80.69"/> draws a Options: (A) circle (B) heptagon (C) hexagon (D) kite (E) line (F) octagon (G) pentagon (H) rectangle (I) sector (J) triangle
</QUESTION>
<ANSWER>
(B)
</ANSWER>
</EXAMPLE>

<TASK>
Write a new prompt that will achieve an accuracy as high as possible and that is different from the old ones.
</TASK>

<RULES>
- It is very important that the new prompt is distinct from ALL the old ones!
- Ensure that you analyse the prompts with a high accuracy and reuse the patterns that worked in the past
- Ensure that you analyse the prompts with a low accuracy and avoid the patterns that didn't work in the past
- Think out loud before creating the prompt. Describe what has worked in the past and what hasn't. Only then create the new prompt.
- Use all available information like prompt length, formal/informal use of language, etc for your analysis.
- Be creative, try out different ways of prompting the model. You may even come up with hypothetical scenarios that might improve the accuracy.
- You are generating system prompts. This means that there should be no placeholders in the prompt, as they cannot be filled at runtime. Instead focus on general instructions that will help the model to solve the task.
- Write your new prompt in double square brackets. Use only plain text for the prompt text and do not add any markdown (i.e. no hashtags, backticks, quotes, etc).
</RULES>
""",
    ),
    max_iterations: int = 3,
    concurrency: int = 10,
) -> dict[str, Union[str, float]]:
    if isinstance(csv_file, str) and os.path.isfile(csv_file):
        csv_file = await File.from_local(csv_file)

    df_train, df_test = await data_prep(csv_file)

    best_prompt, training_accuracy = await prompt_optimizer(
        df_train,
        target_model_config,
        review_model_config,
        optimizer_model_config,
        max_iterations,
        concurrency,
    )

    with flyte.group(name="test_data_evaluation"):
        baseline_test_accuracy = await evaluate_prompt(
            df_test,
            target_model_config,
            review_model_config,
            concurrency,
        )

        target_model_config.prompt = best_prompt
        test_accuracy = await evaluate_prompt(
            df_test,
            target_model_config,
            review_model_config,
            concurrency,
        )

    return {
        "best_prompt": best_prompt,
        "training_accuracy": training_accuracy,
        "baseline_test_accuracy": baseline_test_accuracy,
        "test_accuracy": test_accuracy,
    }

# {{/docs-fragment auto_prompt_engineering}}

# {{docs-fragment main}}
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(auto_prompt_engineering)
    print(run.url)
    run.wait()
# {{/docs-fragment main}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/auto_prompt_engineering/optimizer.py*
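
The control flow of `prompt_optimizer` comes down to a few lines of plain Python. The following is a stdlib-only approximation, not the tutorial's code. `propose` and `evaluate` are hypothetical stand-ins for the optimizer-model call and for `evaluate_prompt`. The loop is capped at `max_iterations` attempts, so a reply without `[[...]]` markers is skipped rather than retried.

```python
import re
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class PromptResult:
    prompt: str
    accuracy: float


# Same pattern as the tutorial: non-greedy, with DOTALL so a multi-line
# prompt between the double square brackets is captured in full.
NEW_PROMPT_RE = re.compile(r"\[\[(.*?)\]\]", re.DOTALL)


def build_prompt_scores(results: list[PromptResult]) -> str:
    """List earlier prompts in ascending order of accuracy (best last)."""
    return "\n".join(
        f"{r.prompt}: {r.accuracy:.2f}"
        for r in sorted(results, key=lambda r: r.accuracy)
    )


def extract_new_prompt(response: str) -> Optional[str]:
    """Return the text inside the first [[...]] pair, or None if there is none."""
    match = NEW_PROMPT_RE.search(response)
    # strip() of surrounding whitespace is an addition in this sketch.
    return match.group(1).strip() if match else None


def optimize(
    history: list[PromptResult],
    propose: Callable[[str], str],  # hypothetical: optimizer-model call
    evaluate: Callable[[str], float],  # hypothetical: accuracy on training rows
    max_iterations: int,
) -> PromptResult:
    for _ in range(max_iterations):
        candidate = extract_new_prompt(propose(build_prompt_scores(history)))
        if candidate is None:
            continue  # skip this attempt; the loop still terminates
        history.append(PromptResult(candidate, evaluate(candidate)))
    return max(history, key=lambda r: r.accuracy)


# Usage with canned optimizer replies and illustrative accuracies.
history = [PromptResult("Think step by step.", 0.52)]
replies = iter([
    "Analysis ... [[Count the vertices first.]]",
    "No brackets in this reply.",
    "[[List each segment, then name the shape.]]",
])
scores = {
    "Count the vertices first.": 0.61,
    "List each segment, then name the shape.": 0.67,
}
best = optimize(history, lambda _: next(replies), scores.__getitem__, max_iterations=3)
print(best.prompt, best.accuracy)
```

Sorting the history in ascending order puts the strongest prompts last, closest to the instructions the optimizer model reads next.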

You can also host your own models on Union. For example, we deploy <code>gpt-oss-20b</code> using vLLM.

```
import union
from union.app.llm import VLLMApp
from flytekit.extras.accelerators import A10G

Model = union.Artifact(name="gpt-oss-20b")

image = union.ImageSpec(
    name="vllm-gpt-oss",
    builder="union",
    apt_packages=["build-essential", "wget", "gnupg"],
    packages=[
        "union[vllm]==0.1.191b0",
        "--pre vllm==0.10.1+gptoss \
        --extra-index-url https://wheels.vllm.ai/gpt-oss/ \
        --extra-index-url https://download.pytorch.org/whl/nightly/cu128 \
        --index-strategy unsafe-best-match",
    ],
).with_commands(
    [
        "wget https://developer.download.nvidia.com/compute/cuda/repos/debian12/x86_64/cuda-keyring_1.1-1_all.deb",
        "dpkg -i cuda-keyring_1.1-1_all.deb",
        "apt-get update",
        "apt-get install -y cuda-toolkit-12-8",
        "/usr/local/cuda/bin/nvcc --version",
        "chown -R union /root",
        "chown -R union /home",
    ]
)

gpt_oss_app = VLLMApp(
    name="gpt-oss-20b-vllm",
    model=Model.query(),
    model_id="gpt-oss",
    container_image=image,
    requests=union.Resources(cpu="5", mem="26Gi", gpu="1", ephemeral_storage="150Gi"),
    accelerator=A10G,
    scaledown_after=300,
    stream_model=True,
    requires_auth=False,
    extra_args="--async-scheduling",
    env={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN_VLLM_V1"},
)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/auto_prompt_engineering/gpt_oss.py*

<p>We use an <code>A10G</code> GPU instance. With <code>stream_model=True</code>, the model weights are streamed directly into GPU memory instead of being downloaded to disk first and then loaded onto the GPU.</p>

<p>Before deploying, cache the model weights from Hugging Face as a Union artifact named <code>gpt-oss-20b</code>, which the app loads through <code>Model.query()</code>:</p>

<pre>
union cache model-from-hf \
    --hf-token-key hf-api-key \
    --artifact-name gpt-oss-20b \
    --cpu 2 \
    --mem 8Gi \
    --ephemeral-storage 100Gi openai/gpt-oss-20b
</pre>

Then deploy it:

<pre>
union deploy apps gpt_oss.py gpt-oss-20b-vllm
</pre>

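To use a hosted model, set its endpoint as <code>hosted_model_uri</code> in <code>ModelConfig</code>. Inference then runs on your own deployment instead of a third-party API. If the target, review, and optimizer models are all self-hosted, your data never leaves your environment.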
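
Because `ModelConfig` is a dataclass, one way to point the target model at a hosted endpoint without mutating the defaults is `dataclasses.replace`, which returns a modified copy. The endpoint URL and the `openai/gpt-oss` model name below are assumptions. LiteLLM generally routes an `openai/`-prefixed model name to the OpenAI-compatible server given as `api_base`. The `/v1` suffix assumes the app exposes vLLM's OpenAI-compatible routes at its root. Check both values against your own deployment.

```python
from dataclasses import dataclass, replace
from typing import Optional


# ModelConfig as defined in the tutorial.
@dataclass
class ModelConfig:
    model_name: str
    hosted_model_uri: Optional[str] = None
    temperature: float = 0.0
    max_tokens: Optional[int] = 1000
    timeout: int = 600
    prompt: str = ""


default_target = ModelConfig(
    model_name="gpt-4.1-mini",
    prompt="Solve the given problem about geometric shapes. Think step by step.",
    max_tokens=10000,
)

# Assumed values: substitute your app's endpoint and served model name.
hosted_target = replace(
    default_target,
    model_name="openai/gpt-oss",
    hosted_model_uri="https://your-gpt-oss-app.example.com/v1",
)

print(hosted_target)
```

Pass a config like `hosted_target` as the `target_model_config` input. `call_model` forwards `hosted_model_uri` to LiteLLM as `api_base`.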

Finally, we wrap the traced `call_model` function in a helper, `generate_and_review`, which calls the target model and then the review model:

```
async def generate_and_review(
    index: int,
    question: str,
    answer: str,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
) -> dict:
    # Generate response from target model
    response = await call_model(
        target_model_config,
        [
            {"role": "system", "content": target_model_config.prompt},
            {"role": "user", "content": question},
        ],
    )

    # Format review prompt with response + answer
    review_messages = [
        {
            "role": "system",
            "content": review_model_config.prompt.format(
                response=response,
                answer=answer,
            ),
        }
    ]
    verdict = await call_model(review_model_config, review_messages)

    # Normalize verdict
    verdict_clean = verdict.strip().lower()
    if verdict_clean not in {"true", "false"}:
        verdict_clean = "not sure"

    return {
        "index": index,
        "model_response": response,
        "is_correct": verdict_clean == "true",
    }
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/auto_prompt_engineering/optimizer.py*
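
You can exercise the generate-then-review pattern without any LLM by injecting the two model calls. This is a stdlib-only sketch, not the tutorial's code. `fake_target` and `fake_review` are hypothetical stubs standing in for `call_model`, and `REVIEW_TEMPLATE` is a trimmed version of the tutorial's review prompt.

```python
import asyncio
from typing import Awaitable, Callable

Messages = list[dict[str, str]]
ModelCall = Callable[[Messages], Awaitable[str]]

# Trimmed review template; like the tutorial's, it has {response} and {answer} slots.
REVIEW_TEMPLATE = "Model Response:\n{response}\n\nGround Truth:\n{answer}\n"


def is_true_verdict(verdict: str) -> bool:
    """Only an exact 'true' (ignoring case and surrounding whitespace) counts."""
    return verdict.strip().lower() == "true"


async def generate_and_review(question: str, answer: str, target_prompt: str,
                              call_target: ModelCall, call_review: ModelCall) -> dict:
    # Step 1: the target model answers using the candidate system prompt.
    response = await call_target([
        {"role": "system", "content": target_prompt},
        {"role": "user", "content": question},
    ])
    # Step 2: the review model compares that response with the ground truth.
    review_prompt = REVIEW_TEMPLATE.format(response=response, answer=answer)
    verdict = await call_review([{"role": "system", "content": review_prompt}])
    return {"model_response": response, "is_correct": is_true_verdict(verdict)}


async def fake_target(messages: Messages) -> str:
    return "The path closes after 7 segments, so the answer is (B)."


async def fake_review(messages: Messages) -> str:
    # Toy judge: correct if the ground truth appears in the response section.
    response_part, truth = messages[0]["content"].split("Ground Truth:")
    return " True\n" if truth.strip() in response_part else "False"


result = asyncio.run(generate_and_review(
    "This SVG path element ... draws a", "(B)", "Think step by step.",
    fake_target, fake_review,
))
print(result["is_correct"])
```

As in the tutorial, a verdict with trailing punctuation such as `True.` counts as incorrect. The review prompt's instruction to answer only `True` or `False` therefore matters.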

## Evaluate prompts

We now define the evaluation process.

Each row of the dataset is evaluated in parallel, and an `asyncio.Semaphore` limits how many rows run at once. A helper function, `run_grouped_task`, wraps each `generate_and_review` call in a `flyte.group` and appends the result to an HTML report. `asyncio.gather` then launches the helper for every row concurrently.
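
The following stdlib-only sketch illustrates the same concurrency pattern. An `asyncio.Semaphore` caps the number of in-flight rows, `asyncio.gather` starts them all, and an `asyncio.Lock` guards the shared counters. The `judge` coroutine is a hypothetical stand-in for `generate_and_review`. The sketch also records peak concurrency to make the semaphore's effect visible.

```python
import asyncio


async def evaluate_rows(rows: list[tuple[str, str]], concurrency: int) -> tuple[float, int]:
    """Return (accuracy, peak number of rows in flight at once)."""
    semaphore = asyncio.Semaphore(concurrency)
    lock = asyncio.Lock()
    stats = {"correct": 0, "processed": 0, "in_flight": 0, "peak": 0}

    async def judge(question: str, answer: str) -> bool:
        # Hypothetical stand-in for generate_and_review's two model calls.
        await asyncio.sleep(0.01)
        return question.endswith(answer)

    async def run_row(question: str, answer: str) -> None:
        async with semaphore:  # at most `concurrency` rows get past this line
            stats["in_flight"] += 1
            stats["peak"] = max(stats["peak"], stats["in_flight"])
            is_correct = await judge(question, answer)
            stats["in_flight"] -= 1
            async with lock:  # serialize updates to the shared counters
                stats["processed"] += 1
                stats["correct"] += int(is_correct)

    await asyncio.gather(*(run_row(q, a) for q, a in rows))
    processed = stats["processed"]
    accuracy = stats["correct"] / processed if processed else 0.0
    return accuracy, stats["peak"]


# 20 toy rows; every fourth one ends in "B" and is therefore judged wrong.
rows = [(f"question {i} -> {'B' if i % 4 == 0 else 'A'}", "A") for i in range(20)]
accuracy, peak = asyncio.run(evaluate_rows(rows, concurrency=5))
print(f"accuracy={accuracy:.2f}, peak in-flight rows={peak}")
```

In single-threaded asyncio, the lock is strictly needed only when an `await` separates reading and writing the counters. Keeping it, as the tutorial does, stays safe if such an `await` is added later.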

Accuracy is the fraction of responses that the review model judges correct against the ground truth. Each result row and accuracy update is flushed to the report as soon as it is ready, so you can watch the evaluation progress live in the UI.
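
The live report is built by flushing HTML fragments one at a time. This sketch is not the tutorial's code: assuming each fragment is simply an HTML string, it renders a results row and an accuracy update. Escaping every text cell with `html.escape`, including the model's response, stops markup in questions such as SVG paths from being interpreted by the browser. `updateAccuracy` is the JavaScript function defined in the report header that `evaluate_prompt` writes.

```python
import html


def render_row(question: str, answer: str, model_response: str, is_correct: bool) -> str:
    """Build one <tr> for the live results table; every text cell is escaped."""
    verdict = (
        "<span class='correct'>✔ Yes</span>"
        if is_correct
        else "<span class='incorrect'>✘ No</span>"
    )
    cells = [html.escape(question), html.escape(answer), html.escape(model_response), verdict]
    return "<tr>" + "".join(f"<td>{cell}</td>" for cell in cells) + "</tr>"


def render_accuracy_update(correct: int, processed: int) -> str:
    """Script snippet that calls the report's updateAccuracy(percent) function."""
    pct = (correct / processed) * 100 if processed else 0.0
    return f"<script>updateAccuracy({pct:.1f});</script>"


row = render_row(
    'This SVG path <path d="M 1,2 L 3,4"/> draws a',
    "(B)",
    "It is a <line>, so (E).",
    False,
)
print(row)
print(render_accuracy_update(3, 4))
```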

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "pandas==2.3.1",
#    "pyarrow==21.0.0",
#    "litellm==1.75.0",
# ]
# main = "auto_prompt_engineering"
# params = ""
# ///

# {{docs-fragment env}}
import asyncio
import html
import os
import re
from dataclasses import dataclass
from typing import Optional, Union

import flyte
import flyte.report
import pandas as pd
from flyte.io._file import File

env = flyte.TaskEnvironment(
    name="auto-prompt-engineering",
    image=flyte.Image.from_uv_script(
        __file__, name="auto-prompt-engineering", pre=True
    ),
    secrets=[flyte.Secret(key="openai_api_key", as_env_var="OPENAI_API_KEY")],
    resources=flyte.Resources(cpu=1),
)

CSS = """
<style>
    body {
        font-family: 'Segoe UI', Roboto, Arial, sans-serif;
    }
    .results-table {
        border-collapse: collapse;
        width: 100%;
        box-shadow: 0 2px 5px rgba(0,0,0,0.1);
        font-size: 14px;
    }
    .results-table th {
        background: linear-gradient(135deg, #4CAF50, #2E7D32);
        color: white;
        padding: 10px;
        text-align: left;
    }
    .results-table td {
        border: 1px solid #ddd;
        padding: 8px;
        vertical-align: top;
    }
    .results-table tr:nth-child(even) {background-color: #f9f9f9;}
    .results-table tr:hover {background-color: #f1f1f1;}
    .correct {color: #2E7D32; font-weight: bold;}
    .incorrect {color: #C62828; font-weight: bold;}
    .summary-card {
        background: #f9fbfd;
        padding: 14px 18px;
        border-radius: 8px;
        box-shadow: 0 1px 4px rgba(0,0,0,0.05);
        max-width: 800px;
        margin-top: 12px;
    }
    .summary-card h3 {
        margin-top: 0;
        color: #1e88e5;
        font-size: 16px;
    }
</style>
"""

# {{/docs-fragment env}}

# {{docs-fragment data_prep}}
@env.task
async def data_prep(csv_file: File | str) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Load Q&A data from a CSV file or URL (for example, a public Google Sheet CSV export)
    and split it into train/test DataFrames. The CSV must have 'input' and 'target' columns.
    """
    df = pd.read_csv(
        await csv_file.download() if isinstance(csv_file, File) else csv_file
    )

    if "input" not in df.columns or "target" not in df.columns:
        raise ValueError("Sheet must contain 'input' and 'target' columns.")

    # Shuffle rows
    df = df.sample(frac=1, random_state=1234).reset_index(drop=True)

    # Train/Test split
    df_train = df.iloc[:150].rename(columns={"input": "question", "target": "answer"})
    df_test = df.iloc[150:250].rename(columns={"input": "question", "target": "answer"})

    return df_train, df_test

# {{/docs-fragment data_prep}}

# {{docs-fragment model_config}}
@dataclass
class ModelConfig:
    model_name: str
    hosted_model_uri: Optional[str] = None
    temperature: float = 0.0
    max_tokens: Optional[int] = 1000
    timeout: int = 600
    prompt: str = ""

# {{/docs-fragment model_config}}

# {{docs-fragment call_model}}
@flyte.trace
async def call_model(
    model_config: ModelConfig,
    messages: list[dict[str, str]],
) -> str:
    from litellm import acompletion

    response = await acompletion(
        model=model_config.model_name,
        api_base=model_config.hosted_model_uri,
        messages=messages,
        temperature=model_config.temperature,
        timeout=model_config.timeout,
        max_tokens=model_config.max_tokens,
    )
    return response.choices[0].message["content"]

# {{/docs-fragment call_model}}

# {{docs-fragment generate_and_review}}
async def generate_and_review(
    index: int,
    question: str,
    answer: str,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
) -> dict:
    # Generate response from target model
    response = await call_model(
        target_model_config,
        [
            {"role": "system", "content": target_model_config.prompt},
            {"role": "user", "content": question},
        ],
    )

    # Format review prompt with response + answer
    review_messages = [
        {
            "role": "system",
            "content": review_model_config.prompt.format(
                response=response,
                answer=answer,
            ),
        }
    ]
    verdict = await call_model(review_model_config, review_messages)

    # Normalize verdict
    verdict_clean = verdict.strip().lower()
    if verdict_clean not in {"true", "false"}:
        verdict_clean = "not sure"

    return {
        "index": index,
        "model_response": response,
        "is_correct": verdict_clean == "true",
    }

# {{/docs-fragment generate_and_review}}

async def run_grouped_task(
    i,
    index,
    question,
    answer,
    semaphore,
    target_model_config,
    review_model_config,
    counter,
    counter_lock,
):
    async with semaphore:
        with flyte.group(name=f"row-{i}"):
            result = await generate_and_review(
                index,
                question,
                answer,
                target_model_config,
                review_model_config,
            )

            async with counter_lock:
                # Update counters
                counter["processed"] += 1
                if result["is_correct"]:
                    counter["correct"] += 1
                    correct_html = "<span class='correct'>✔ Yes</span>"
                else:
                    correct_html = "<span class='incorrect'>✘ No</span>"

                # Calculate accuracy
                accuracy_pct = (counter["correct"] / counter["processed"]) * 100

            # Update chart
            await flyte.report.log.aio(
                f"<script>updateAccuracy({accuracy_pct});</script>",
                do_flush=True,
            )

            # Add row to table
            await flyte.report.log.aio(
                f"""
                <tr>
                    <td>{html.escape(question)}</td>
                    <td>{html.escape(answer)}</td>
                    <td>{html.escape(result['model_response'])}</td>
                    <td>{correct_html}</td>
                </tr>
                """,
                do_flush=True,
            )

            return result

# {{docs-fragment evaluate_prompt}}
@env.task(report=True)
async def evaluate_prompt(
    df: pd.DataFrame,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
    concurrency: int,
) -> float:
    semaphore = asyncio.Semaphore(concurrency)
    counter = {"correct": 0, "processed": 0}
    counter_lock = asyncio.Lock()

    # Write initial HTML structure
    await flyte.report.log.aio(
        CSS
        + """
        <script>
            function updateAccuracy(percent) {
                const bar = document.getElementById('acc-bar');
                const label = document.getElementById('acc-label');
                bar.setAttribute('width', percent * 3);
                label.textContent = `Accuracy: ${percent.toFixed(1)}%`;
            }
        </script>

        <h2 style="margin-top:0;">Model Evaluation Results</h2>
        <h3>Live Accuracy</h3>
        <svg width="320" height="30" id="accuracy-chart">
            <defs>
                <linearGradient id="acc-gradient" x1="0" x2="1" y1="0" y2="0">
                    <stop offset="0%" stop-color="#66bb6a"/>
                    <stop offset="100%" stop-color="#2e7d32"/>
                </linearGradient>
            </defs>
            <rect width="300" height="20" fill="#ddd" rx="5" ry="5"></rect>
            <rect id="acc-bar" width="0" height="20" fill="url(#acc-gradient)" rx="5" ry="5"></rect>
            <text id="acc-label" x="150" y="15" font-size="12" font-weight="bold" text-anchor="middle" fill="#000">
                Accuracy: 0.0%
            </text>
        </svg>

        <table class="results-table">
            <thead>
                <tr>
                    <th>Question</th>
                    <th>Answer</th>
                    <th>Model Response</th>
                    <th>Correct?</th>
                </tr>
            </thead>
            <tbody>
        """,
        do_flush=True,
    )

    # Launch tasks concurrently
    tasks = [
        run_grouped_task(
            i,
            row.Index,
            row.question,
            row.answer,
            semaphore,
            target_model_config,
            review_model_config,
            counter,
            counter_lock,
        )
        for i, row in enumerate(df.itertuples(index=True))
    ]
    await asyncio.gather(*tasks)

    # Close table
    await flyte.report.log.aio("</tbody></table>", do_flush=True)

    async with counter_lock:
        return (
            (counter["correct"] / counter["processed"]) if counter["processed"] else 0.0
        )

# {{/docs-fragment evaluate_prompt}}

@dataclass
class PromptResult:
    prompt: str
    accuracy: float

# {{docs-fragment prompt_optimizer}}
@env.task(report=True)
async def prompt_optimizer(
    df_train: pd.DataFrame,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
    optimizer_model_config: ModelConfig,
    max_iterations: int,
    concurrency: int,
) -> tuple[str, float]:
    prompt_accuracies: list[PromptResult] = []

    # Send styling + table header immediately
    await flyte.report.log.aio(
        CSS
        + """
    <h2 style="margin-bottom:6px;">📊 Prompt Accuracy Comparison</h2>
    <table class="results-table">
        <thead>
            <tr>
                <th>Prompt</th>
                <th>Accuracy</th>
            </tr>
        </thead>
    <tbody>
    """,
        do_flush=True,
    )

    # Step 1: Evaluate starting prompt and stream row
    with flyte.group(name="baseline_evaluation"):
        starting_accuracy = await evaluate_prompt(
            df_train,
            target_model_config,
            review_model_config,
            concurrency,
        )
        prompt_accuracies.append(
            PromptResult(prompt=target_model_config.prompt, accuracy=starting_accuracy)
        )

        await _log_prompt_row(target_model_config.prompt, starting_accuracy)

    # Step 2: Optimize prompts one by one, streaming after each.
    # A fixed number of attempts means a reply without [[...]] cannot loop forever.
    for step in range(1, max_iterations + 1):
        with flyte.group(name=f"prompt_optimization_step_{step}"):
            # Prepare prompt scores string for optimizer
            prompt_scores_str = "\n".join(
                f"{result.prompt}: {result.accuracy:.2f}"
                for result in sorted(prompt_accuracies, key=lambda x: x.accuracy)
            )

            optimizer_model_prompt = optimizer_model_config.prompt.format(
                prompt_scores_str=prompt_scores_str
            )
            response = await call_model(
                optimizer_model_config,
                [{"role": "system", "content": optimizer_model_prompt}],
            )
            response = response.strip()

            match = re.search(r"\[\[(.*?)\]\]", response, re.DOTALL)
            if not match:
                print("No new prompt found. Skipping.")
                continue

            new_prompt = match.group(1)
            target_model_config.prompt = new_prompt
            accuracy = await evaluate_prompt(
                df_train,
                target_model_config,
                review_model_config,
                concurrency,
            )
            prompt_accuracies.append(PromptResult(prompt=new_prompt, accuracy=accuracy))

            # Log this new prompt row immediately
            await _log_prompt_row(new_prompt, accuracy)

    # Close table
    await flyte.report.log.aio("</tbody></table>", do_flush=True)

    # Find best
    best_result = max(prompt_accuracies, key=lambda x: x.accuracy)
    improvement = best_result.accuracy - starting_accuracy

    # Summary
    await flyte.report.log.aio(
        f"""
    <div class="summary-card">
        <h3>🏆 Summary</h3>
        <p><strong>Best Prompt:</strong> {html.escape(best_result.prompt)}</p>
        <p><strong>Best Accuracy:</strong> {best_result.accuracy*100:.2f}%</p>
        <p><strong>Improvement Over Baseline:</strong> {improvement*100:.2f}%</p>
    </div>
    """,
        do_flush=True,
    )

    return best_result.prompt, best_result.accuracy

# {{/docs-fragment prompt_optimizer}}

async def _log_prompt_row(prompt: str, accuracy: float):
    """Helper to log a single prompt/accuracy row to Flyte report."""
    pct = accuracy * 100
    if pct > 80:
        color = "linear-gradient(90deg, #4CAF50, #81C784)"
    elif pct > 60:
        color = "linear-gradient(90deg, #FFC107, #FFD54F)"
    else:
        color = "linear-gradient(90deg, #F44336, #E57373)"

    await flyte.report.log.aio(
        f"""
        <tr>
            <td>{html.escape(prompt)}</td>
            <td>
                {pct:.1f}%
                <div class="accuracy-bar-container">
                    <div class="accuracy-bar" style="width:{pct*1.6}px; height:8px; background:{color};"></div>
                </div>
            </td>
        </tr>
        """,
        do_flush=True,
    )

# {{docs-fragment auto_prompt_engineering}}
@env.task
async def auto_prompt_engineering(
    csv_file: File | str = "https://dub.sh/geometric-shapes",
    target_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1-mini",
        hosted_model_uri=None,
        prompt="Solve the given problem about geometric shapes. Think step by step.",
        max_tokens=10000,
    ),
    review_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1-mini",
        hosted_model_uri=None,
        prompt="""You are a review model tasked with evaluating the correctness of a response to a geometric shapes problem.
The response may contain detailed steps and explanations, but the final answer is the key point.
Please determine if the final answer provided in the response is correct based on the ground truth answer.
Respond with 'True' if the final answer is correct and 'False' if it is not.
Only respond with 'True' or 'False', nothing else.

Model Response:
{response}

Ground Truth:
{answer}
""",
    ),
    optimizer_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1",
        hosted_model_uri=None,
        temperature=0.7,
        max_tokens=None,
        prompt="""
<EXPLANATION>
I have some prompts along with their corresponding accuracies.
The prompts are arranged in ascending order based on their accuracy, where higher accuracy indicates better quality.
</EXPLANATION>

<PROMPTS>
{prompt_scores_str}
</PROMPTS>

Each prompt was used together with a problem statement around geometric shapes.

<EXAMPLE>
<QUESTION>
This SVG path element <path d="M 55.57,80.69 L 57.38,65.80 M 57.38,65.80 L 48.90,57.46 M 48.90,57.46 L 45.58,47.78 M 45.58,47.78 L 53.25,36.07 L 66.29,48.90 L 78.69,61.09 L 55.57,80.69"/> draws a Options: (A) circle (B) heptagon (C) hexagon (D) kite (E) line (F) octagon (G) pentagon (H) rectangle (I) sector (J) triangle
</QUESTION>
<ANSWER>
(B)
</ANSWER>
</EXAMPLE>

<TASK>
Write a new prompt that will achieve an accuracy as high as possible and that is different from the old ones.
</TASK>

<RULES>
- It is very important that the new prompt is distinct from ALL the old ones!
- Ensure that you analyse the prompts with a high accuracy and reuse the patterns that worked in the past.
- Ensure that you analyse the prompts with a low accuracy and avoid the patterns that did not work in the past.
- Think out loud before creating the prompt. Describe what has worked in the past and what hasn't. Only then create the new prompt.
- Use all available information, like prompt length, formal/informal use of language, etc., for your analysis.
- Be creative and try out different ways of prompting the model. You may even come up with hypothetical scenarios that might improve the accuracy.
- You are generating system prompts. This means that there should be no placeholders in the prompt, as they cannot be filled at runtime. Instead, focus on general instructions that will help the model solve the task.
- Write your new prompt in double square brackets. Use only plain text for the prompt text and do not add any markdown (e.g., no hashtags, backticks, quotes, etc.).
</RULES>
""",
    ),
    max_iterations: int = 3,
    concurrency: int = 10,
) -> dict[str, Union[str, float]]:
    if isinstance(csv_file, str) and os.path.isfile(csv_file):
        csv_file = await File.from_local(csv_file)

    df_train, df_test = await data_prep(csv_file)

    best_prompt, training_accuracy = await prompt_optimizer(
        df_train,
        target_model_config,
        review_model_config,
        optimizer_model_config,
        max_iterations,
        concurrency,
    )

    with flyte.group(name="test_data_evaluation"):
        baseline_test_accuracy = await evaluate_prompt(
            df_test,
            target_model_config,
            review_model_config,
            concurrency,
        )

        target_model_config.prompt = best_prompt
        test_accuracy = await evaluate_prompt(
            df_test,
            target_model_config,
            review_model_config,
            concurrency,
        )

    return {
        "best_prompt": best_prompt,
        "training_accuracy": training_accuracy,
        "baseline_test_accuracy": baseline_test_accuracy,
        "test_accuracy": test_accuracy,
    }

# {{/docs-fragment auto_prompt_engineering}}

# {{docs-fragment main}}
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(auto_prompt_engineering)
    print(run.url)
    run.wait()
# {{/docs-fragment main}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/auto_prompt_engineering/optimizer.py*

## Optimize prompts

Optimization builds on evaluation. We give the optimizer model:

- the history of prompts tested so far, and
- their accuracies.

The model then proposes a new prompt.
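Building the optimizer's input and reading its reply can be sketched with two small helpers. They mirror what `prompt_optimizer` does inline. However, `format_history` and `extract_prompt` are hypothetical names used only for illustration, and the prompts and scores below are made up.

```python
import re
from dataclasses import dataclass
from typing import Optional


@dataclass
class PromptResult:
    prompt: str
    accuracy: float


def format_history(results: list[PromptResult]) -> str:
    # One "prompt: accuracy" line per attempt, sorted by ascending accuracy,
    # which is the ordering the optimizer prompt tells the model to expect.
    return "\n".join(
        f"{r.prompt}: {r.accuracy:.2f}"
        for r in sorted(results, key=lambda r: r.accuracy)
    )


def extract_prompt(response: str) -> Optional[str]:
    # The optimizer is told to wrap its new prompt in [[...]]; take the first match.
    match = re.search(r"\[\[(.*?)\]\]", response, re.DOTALL)
    return match.group(1) if match else None


history = [
    PromptResult("Think step by step.", 0.62),
    PromptResult("Trace each segment.", 0.71),
    PromptResult("Be brief.", 0.40),
]
print(format_history(history))
# Be brief.: 0.40
# Think step by step.: 0.62
# Trace each segment.: 0.71
print(extract_prompt("Analysis first...\n[[Count the vertices, then name the shape.]]"))
# Count the vertices, then name the shape.
```

The non-greedy `.*?` together with `re.DOTALL` captures a multi-line prompt but stops at the first closing `]]`. This matters because the optimizer is asked to "think out loud" before writing its prompt.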

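We start with a _baseline_ evaluation of the user-provided prompt on the training set. In each iteration, the optimizer receives the prompt history sorted by ascending accuracy and replies with a new prompt in double square brackets, which we extract, evaluate, and log to the report. We stop at the iteration limit (`max_iterations`) and return the prompt with the highest training accuracy.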
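The control flow reduces to a short loop. The sketch below is a simplified, Flyte-free version under stated assumptions. `optimize`, `fake_evaluate`, and `fake_propose` are hypothetical, the scores are hard-coded, and the history is a list of `(prompt, accuracy)` tuples rather than `PromptResult` objects. This sketch counts every optimizer call toward `max_iterations`, so a reply without a bracketed prompt cannot make the loop run forever.

```python
import asyncio
from typing import Awaitable, Callable, Optional

History = list[tuple[str, float]]


async def optimize(
    baseline_prompt: str,
    evaluate: Callable[[str], Awaitable[float]],
    propose: Callable[[History], Awaitable[Optional[str]]],
    max_iterations: int,
) -> tuple[str, float]:
    # Baseline: score the user-provided prompt before any optimization.
    history: History = [(baseline_prompt, await evaluate(baseline_prompt))]
    # Each iteration requests one candidate; a failed proposal still uses up
    # the iteration, so the loop always terminates.
    for _ in range(max_iterations):
        candidate = await propose(history)
        if candidate is None:
            continue
        history.append((candidate, await evaluate(candidate)))
    # The best prompt seen so far wins, which may be the baseline itself.
    return max(history, key=lambda item: item[1])


# Hypothetical stand-ins: fixed scores and a scripted optimizer that
# proposes "v1", then fails to produce a prompt, then proposes "v2".
SCORES = {"base": 0.5, "v1": 0.7, "v2": 0.6}
proposals = iter(["v1", None, "v2"])


async def fake_evaluate(prompt: str) -> float:
    return SCORES[prompt]


async def fake_propose(history: History) -> Optional[str]:
    return next(proposals)


best = asyncio.run(optimize("base", fake_evaluate, fake_propose, max_iterations=3))
print(best)  # ('v1', 0.7)
```

Because the baseline is part of the history, the optimizer can never return a prompt that scored worse than the starting one on the training set.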

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "pandas==2.3.1",
#    "pyarrow==21.0.0",
#    "litellm==1.75.0",
# ]
# main = "auto_prompt_engineering"
# params = ""
# ///

# {{docs-fragment env}}
import asyncio
import html
import os
import re
from dataclasses import dataclass
from typing import Optional, Union

import flyte
import flyte.report
import pandas as pd
from flyte.io import File

env = flyte.TaskEnvironment(
    name="auto-prompt-engineering",
    image=flyte.Image.from_uv_script(
        __file__, name="auto-prompt-engineering", pre=True
    ),
    secrets=[flyte.Secret(key="openai_api_key", as_env_var="OPENAI_API_KEY")],
    resources=flyte.Resources(cpu=1),
)

CSS = """
<style>
    body {
        font-family: 'Segoe UI', Roboto, Arial, sans-serif;
    }
    .results-table {
        border-collapse: collapse;
        width: 100%;
        box-shadow: 0 2px 5px rgba(0,0,0,0.1);
        font-size: 14px;
    }
    .results-table th {
        background: linear-gradient(135deg, #4CAF50, #2E7D32);
        color: white;
        padding: 10px;
        text-align: left;
    }
    .results-table td {
        border: 1px solid #ddd;
        padding: 8px;
        vertical-align: top;
    }
    .results-table tr:nth-child(even) {background-color: #f9f9f9;}
    .results-table tr:hover {background-color: #f1f1f1;}
    .correct {color: #2E7D32; font-weight: bold;}
    .incorrect {color: #C62828; font-weight: bold;}
    .summary-card {
        background: #f9fbfd;
        padding: 14px 18px;
        border-radius: 8px;
        box-shadow: 0 1px 4px rgba(0,0,0,0.05);
        max-width: 800px;
        margin-top: 12px;
    }
    .summary-card h3 {
        margin-top: 0;
        color: #1e88e5;
        font-size: 16px;
    }
</style>
"""

# {{/docs-fragment env}}

# {{docs-fragment data_prep}}
@env.task
async def data_prep(csv_file: File | str) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Load Q&A data from a CSV file or URL (for example, a public Google Sheet CSV export)
    and split it into train/test DataFrames. The CSV must have 'input' and 'target' columns.
    """
    df = pd.read_csv(
        await csv_file.download() if isinstance(csv_file, File) else csv_file
    )

    if "input" not in df.columns or "target" not in df.columns:
        raise ValueError("Sheet must contain 'input' and 'target' columns.")

    # Shuffle rows
    df = df.sample(frac=1, random_state=1234).reset_index(drop=True)

    # Train/Test split
    df_train = df.iloc[:150].rename(columns={"input": "question", "target": "answer"})
    df_test = df.iloc[150:250].rename(columns={"input": "question", "target": "answer"})

    return df_train, df_test

# {{/docs-fragment data_prep}}

# {{docs-fragment model_config}}
@dataclass
class ModelConfig:
    model_name: str
    hosted_model_uri: Optional[str] = None
    temperature: float = 0.0
    max_tokens: Optional[int] = 1000
    timeout: int = 600
    prompt: str = ""

# {{/docs-fragment model_config}}

# {{docs-fragment call_model}}
@flyte.trace
async def call_model(
    model_config: ModelConfig,
    messages: list[dict[str, str]],
) -> str:
    from litellm import acompletion

    response = await acompletion(
        model=model_config.model_name,
        api_base=model_config.hosted_model_uri,
        messages=messages,
        temperature=model_config.temperature,
        timeout=model_config.timeout,
        max_tokens=model_config.max_tokens,
    )
    return response.choices[0].message["content"]

# {{/docs-fragment call_model}}

# {{docs-fragment generate_and_review}}
async def generate_and_review(
    index: int,
    question: str,
    answer: str,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
) -> dict:
    # Generate response from target model
    response = await call_model(
        target_model_config,
        [
            {"role": "system", "content": target_model_config.prompt},
            {"role": "user", "content": question},
        ],
    )

    # Format review prompt with response + answer
    review_messages = [
        {
            "role": "system",
            "content": review_model_config.prompt.format(
                response=response,
                answer=answer,
            ),
        }
    ]
    verdict = await call_model(review_model_config, review_messages)

    # Normalize verdict
    verdict_clean = verdict.strip().lower()
    if verdict_clean not in {"true", "false"}:
        verdict_clean = "not sure"

    return {
        "index": index,
        "model_response": response,
        "is_correct": verdict_clean == "true",
    }

# {{/docs-fragment generate_and_review}}

async def run_grouped_task(
    i,
    index,
    question,
    answer,
    semaphore,
    target_model_config,
    review_model_config,
    counter,
    counter_lock,
):
    async with semaphore:
        with flyte.group(name=f"row-{i}"):
            result = await generate_and_review(
                index,
                question,
                answer,
                target_model_config,
                review_model_config,
            )

            async with counter_lock:
                # Update counters
                counter["processed"] += 1
                if result["is_correct"]:
                    counter["correct"] += 1
                    correct_html = "<span class='correct'>✔ Yes</span>"
                else:
                    correct_html = "<span class='incorrect'>✘ No</span>"

                # Calculate accuracy
                accuracy_pct = (counter["correct"] / counter["processed"]) * 100

            # Update chart
            await flyte.report.log.aio(
                f"<script>updateAccuracy({accuracy_pct});</script>",
                do_flush=True,
            )

            # Add row to table
            await flyte.report.log.aio(
                f"""
                <tr>
                    <td>{html.escape(question)}</td>
                    <td>{html.escape(answer)}</td>
                    <td>{html.escape(result['model_response'])}</td>
                    <td>{correct_html}</td>
                </tr>
                """,
                do_flush=True,
            )

            return result

# {{docs-fragment evaluate_prompt}}
@env.task(report=True)
async def evaluate_prompt(
    df: pd.DataFrame,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
    concurrency: int,
) -> float:
    semaphore = asyncio.Semaphore(concurrency)
    counter = {"correct": 0, "processed": 0}
    counter_lock = asyncio.Lock()

    # Write initial HTML structure
    await flyte.report.log.aio(
        CSS
        + """
        <script>
            function updateAccuracy(percent) {
                const bar = document.getElementById('acc-bar');
                const label = document.getElementById('acc-label');
                bar.setAttribute('width', percent * 3);
                label.textContent = `Accuracy: ${percent.toFixed(1)}%`;
            }
        </script>

        <h2 style="margin-top:0;">Model Evaluation Results</h2>
        <h3>Live Accuracy</h3>
        <svg width="320" height="30" id="accuracy-chart">
            <defs>
                <linearGradient id="acc-gradient" x1="0" x2="1" y1="0" y2="0">
                    <stop offset="0%" stop-color="#66bb6a"/>
                    <stop offset="100%" stop-color="#2e7d32"/>
                </linearGradient>
            </defs>
            <rect width="300" height="20" fill="#ddd" rx="5" ry="5"></rect>
            <rect id="acc-bar" width="0" height="20" fill="url(#acc-gradient)" rx="5" ry="5"></rect>
            <text id="acc-label" x="150" y="15" font-size="12" font-weight="bold" text-anchor="middle" fill="#000">
                Accuracy: 0.0%
            </text>
        </svg>

        <table class="results-table">
            <thead>
                <tr>
                    <th>Question</th>
                    <th>Answer</th>
                    <th>Model Response</th>
                    <th>Correct?</th>
                </tr>
            </thead>
            <tbody>
        """,
        do_flush=True,
    )

    # Launch tasks concurrently
    tasks = [
        run_grouped_task(
            i,
            row.Index,
            row.question,
            row.answer,
            semaphore,
            target_model_config,
            review_model_config,
            counter,
            counter_lock,
        )
        for i, row in enumerate(df.itertuples(index=True))
    ]
    await asyncio.gather(*tasks)

    # Close table
    await flyte.report.log.aio("</tbody></table>", do_flush=True)

    async with counter_lock:
        return (
            (counter["correct"] / counter["processed"]) if counter["processed"] else 0.0
        )

# {{/docs-fragment evaluate_prompt}}

@dataclass
class PromptResult:
    prompt: str
    accuracy: float

# {{docs-fragment prompt_optimizer}}
@env.task(report=True)
async def prompt_optimizer(
    df_train: pd.DataFrame,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
    optimizer_model_config: ModelConfig,
    max_iterations: int,
    concurrency: int,
) -> tuple[str, float]:
    prompt_accuracies: list[PromptResult] = []

    # Send styling + table header immediately
    await flyte.report.log.aio(
        CSS
        + """
    <h2 style="margin-bottom:6px;">📊 Prompt Accuracy Comparison</h2>
    <table class="results-table">
        <thead>
            <tr>
                <th>Prompt</th>
                <th>Accuracy</th>
            </tr>
        </thead>
    <tbody>
    """,
        do_flush=True,
    )

    # Step 1: Evaluate starting prompt and stream row
    with flyte.group(name="baseline_evaluation"):
        starting_accuracy = await evaluate_prompt(
            df_train,
            target_model_config,
            review_model_config,
            concurrency,
        )
        prompt_accuracies.append(
            PromptResult(prompt=target_model_config.prompt, accuracy=starting_accuracy)
        )

        await _log_prompt_row(target_model_config.prompt, starting_accuracy)

    # Step 2: Optimize prompts one by one, streaming after each.
    # A fixed number of attempts means a reply without [[...]] cannot loop forever.
    for step in range(1, max_iterations + 1):
        with flyte.group(name=f"prompt_optimization_step_{step}"):
            # Prepare prompt scores string for optimizer
            prompt_scores_str = "\n".join(
                f"{result.prompt}: {result.accuracy:.2f}"
                for result in sorted(prompt_accuracies, key=lambda x: x.accuracy)
            )

            optimizer_model_prompt = optimizer_model_config.prompt.format(
                prompt_scores_str=prompt_scores_str
            )
            response = await call_model(
                optimizer_model_config,
                [{"role": "system", "content": optimizer_model_prompt}],
            )
            response = response.strip()

            match = re.search(r"\[\[(.*?)\]\]", response, re.DOTALL)
            if not match:
                print("No new prompt found. Skipping.")
                continue

            new_prompt = match.group(1)
            target_model_config.prompt = new_prompt
            accuracy = await evaluate_prompt(
                df_train,
                target_model_config,
                review_model_config,
                concurrency,
            )
            prompt_accuracies.append(PromptResult(prompt=new_prompt, accuracy=accuracy))

            # Log this new prompt row immediately
            await _log_prompt_row(new_prompt, accuracy)

    # Close table
    await flyte.report.log.aio("</tbody></table>", do_flush=True)

    # Find best
    best_result = max(prompt_accuracies, key=lambda x: x.accuracy)
    improvement = best_result.accuracy - starting_accuracy

    # Summary
    await flyte.report.log.aio(
        f"""
    <div class="summary-card">
        <h3>🏆 Summary</h3>
        <p><strong>Best Prompt:</strong> {html.escape(best_result.prompt)}</p>
        <p><strong>Best Accuracy:</strong> {best_result.accuracy*100:.2f}%</p>
        <p><strong>Improvement Over Baseline:</strong> {improvement*100:.2f}%</p>
    </div>
    """,
        do_flush=True,
    )

    return best_result.prompt, best_result.accuracy

# {{/docs-fragment prompt_optimizer}}

async def _log_prompt_row(prompt: str, accuracy: float):
    """Helper to log a single prompt/accuracy row to Flyte report."""
    pct = accuracy * 100
    if pct > 80:
        color = "linear-gradient(90deg, #4CAF50, #81C784)"
    elif pct > 60:
        color = "linear-gradient(90deg, #FFC107, #FFD54F)"
    else:
        color = "linear-gradient(90deg, #F44336, #E57373)"

    await flyte.report.log.aio(
        f"""
        <tr>
            <td>{html.escape(prompt)}</td>
            <td>
                {pct:.1f}%
                <div class="accuracy-bar-container">
                    <div class="accuracy-bar" style="width:{pct*1.6}px; height:8px; background:{color};"></div>
                </div>
            </td>
        </tr>
        """,
        do_flush=True,
    )

# {{docs-fragment auto_prompt_engineering}}
@env.task
async def auto_prompt_engineering(
    csv_file: File | str = "https://dub.sh/geometric-shapes",
    target_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1-mini",
        hosted_model_uri=None,
        prompt="Solve the given problem about geometric shapes. Think step by step.",
        max_tokens=10000,
    ),
    review_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1-mini",
        hosted_model_uri=None,
        prompt="""You are a review model tasked with evaluating the correctness of a response to a geometric shapes problem.
The response may contain detailed steps and explanations, but the final answer is the key point.
Please determine if the final answer provided in the response is correct based on the ground truth answer.
Respond with 'True' if the final answer is correct and 'False' if it is not.
Only respond with 'True' or 'False', nothing else.

Model Response:
{response}

Ground Truth:
{answer}
""",
    ),
    optimizer_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1",
        hosted_model_uri=None,
        temperature=0.7,
        max_tokens=None,
        prompt="""
<EXPLANATION>
I have some prompts along with their corresponding accuracies.
The prompts are arranged in ascending order based on their accuracy, where higher accuracy indicates better quality.
</EXPLANATION>

<PROMPTS>
{prompt_scores_str}
</PROMPTS>

Each prompt was used together with a problem statement around geometric shapes.

<EXAMPLE>
<QUESTION>
This SVG path element <path d="M 55.57,80.69 L 57.38,65.80 M 57.38,65.80 L 48.90,57.46 M 48.90,57.46 L 45.58,47.78 M 45.58,47.78 L 53.25,36.07 L 66.29,48.90 L 78.69,61.09 L 55.57,80.69"/> draws a Options: (A) circle (B) heptagon (C) hexagon (D) kite (E) line (F) octagon (G) pentagon (H) rectangle (I) sector (J) triangle
</QUESTION>
<ANSWER>
(B)
</ANSWER>
</EXAMPLE>

<TASK>
Write a new prompt that will achieve an accuracy as high as possible and that is different from the old ones.
</TASK>

<RULES>
- It is very important that the new prompt is distinct from ALL the old ones!
- Ensure that you analyse the prompts with a high accuracy and reuse the patterns that worked in the past
- Ensure that you analyse the prompts with a low accuracy and avoid the patterns that didn't work in the past
- Think out loud before creating the prompt. Describe what has worked in the past and what hasn't. Only then create the new prompt.
- Use all available information like prompt length, formal/informal use of language, etc. for your analysis.
- Be creative, try out different ways of prompting the model. You may even come up with hypothetical scenarios that might improve the accuracy.
- You are generating system prompts. This means that there should be no placeholders in the prompt, as they cannot be filled at runtime. Instead focus on general instructions that will help the model to solve the task.
- Write your new prompt in double square brackets. Use only plain text for the prompt text and do not add any markdown (i.e. no hashtags, backticks, quotes, etc).
</RULES>
""",
    ),
    max_iterations: int = 3,
    concurrency: int = 10,
) -> dict[str, Union[str, float]]:
    if isinstance(csv_file, str) and os.path.isfile(csv_file):
        csv_file = await File.from_local(csv_file)

    df_train, df_test = await data_prep(csv_file)

    best_prompt, training_accuracy = await prompt_optimizer(
        df_train,
        target_model_config,
        review_model_config,
        optimizer_model_config,
        max_iterations,
        concurrency,
    )

    with flyte.group(name="test_data_evaluation"):
        baseline_test_accuracy = await evaluate_prompt(
            df_test,
            target_model_config,
            review_model_config,
            concurrency,
        )

        target_model_config.prompt = best_prompt
        test_accuracy = await evaluate_prompt(
            df_test,
            target_model_config,
            review_model_config,
            concurrency,
        )

    return {
        "best_prompt": best_prompt,
        "training_accuracy": training_accuracy,
        "baseline_test_accuracy": baseline_test_accuracy,
        "test_accuracy": test_accuracy,
    }

# {{/docs-fragment auto_prompt_engineering}}

# {{docs-fragment main}}
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(auto_prompt_engineering)
    print(run.url)
    run.wait()
# {{/docs-fragment main}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/auto_prompt_engineering/optimizer.py*

At the end, `prompt_optimizer` returns the highest-scoring prompt and its training accuracy. The report adds a row as each prompt is evaluated, so you can watch accuracy change across iterations and see every prompt that was tried.

![Report](https://raw.githubusercontent.com/unionai/unionai-docs-static/main/gifs/tutorials/prompt_engineering/prompt_accuracies.png)
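Each optimization step is mostly string handling. The history of prompts and scores is rendered worst-to-best, as the optimizer prompt describes, and the new candidate is pulled out of the `[[...]]` wrapper the optimizer is asked to use. The following is a minimal sketch of that loop with the LLM calls removed, not the tutorial's code. The `optimizer` and `score` callables are hypothetical stand-ins for `call_model` and `evaluate_prompt`, and this sketch caps the loop at `max_iterations` optimizer calls so that a response without brackets is skipped rather than retried.

```python
import re
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class PromptResult:
    prompt: str
    accuracy: float


def format_history(results: list[PromptResult]) -> str:
    # Ascending by accuracy, so the strongest prompts appear last.
    return "\n".join(
        f"{r.prompt}: {r.accuracy:.2f}"
        for r in sorted(results, key=lambda r: r.accuracy)
    )


def extract_prompt(response: str) -> Optional[str]:
    # The optimizer is instructed to wrap its new prompt in [[double brackets]].
    match = re.search(r"\[\[(.*?)\]\]", response, re.DOTALL)
    return match.group(1).strip() if match else None


def optimize(
    baseline: str,
    optimizer: Callable[[str], str],  # hypothetical stand-in for call_model
    score: Callable[[str], float],  # hypothetical stand-in for evaluate_prompt
    max_iterations: int,
) -> PromptResult:
    history = [PromptResult(baseline, score(baseline))]
    for _ in range(max_iterations):
        new_prompt = extract_prompt(optimizer(format_history(history)))
        if new_prompt is None:
            continue  # malformed response: skip this step
        history.append(PromptResult(new_prompt, score(new_prompt)))
    return max(history, key=lambda r: r.accuracy)


# Usage with canned optimizer responses and a toy scorer.
responses = iter(["Analysis first. [[Count the vertices before answering.]]", "no brackets"])
best = optimize(
    "Think step by step.",
    optimizer=lambda history: next(responses),
    score=lambda p: 0.9 if "vertices" in p else 0.5,
    max_iterations=2,
)
print(best)  # PromptResult(prompt='Count the vertices before answering.', accuracy=0.9)
```

`max` picks the best prompt seen so far. If no candidate beats the starting prompt, the baseline itself is returned.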

## Build the full pipeline

The entrypoint task, `auto_prompt_engineering`, wires everything together:

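- Accepts the dataset, the target, review, and optimizer model configs, the number of optimization iterations, and the concurrency limit.
- Runs `data_prep` to shuffle the data and split it into training and test sets.
- Calls `prompt_optimizer` on the training set.
- Evaluates both the baseline prompt and the best prompt on the held-out test set.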

```
# {{docs-fragment auto_prompt_engineering}}
@env.task
async def auto_prompt_engineering(
    csv_file: File | str = "https://dub.sh/geometric-shapes",
    target_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1-mini",
        hosted_model_uri=None,
        prompt="Solve the given problem about geometric shapes. Think step by step.",
        max_tokens=10000,
    ),
    review_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1-mini",
        hosted_model_uri=None,
        prompt="""You are a review model tasked with evaluating the correctness of a response to a geometric shapes problem.
The response may contain detailed steps and explanations, but the final answer is the key point.
Please determine if the final answer provided in the response is correct based on the ground truth answer.
Respond with 'True' if the final answer is correct and 'False' if it is not.
Only respond with 'True' or 'False', nothing else.

Model Response:
{response}

Ground Truth:
{answer}
""",
    ),
    optimizer_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1",
        hosted_model_uri=None,
        temperature=0.7,
        max_tokens=None,
        prompt="""
<EXPLANATION>
I have some prompts along with their corresponding accuracies.
The prompts are arranged in ascending order based on their accuracy, where higher accuracy indicates better quality.
</EXPLANATION>

<PROMPTS>
{prompt_scores_str}
</PROMPTS>

Each prompt was used together with a problem statement around geometric shapes.

<EXAMPLE>
<QUESTION>
This SVG path element <path d="M 55.57,80.69 L 57.38,65.80 M 57.38,65.80 L 48.90,57.46 M 48.90,57.46 L 45.58,47.78 M 45.58,47.78 L 53.25,36.07 L 66.29,48.90 L 78.69,61.09 L 55.57,80.69"/> draws a Options: (A) circle (B) heptagon (C) hexagon (D) kite (E) line (F) octagon (G) pentagon (H) rectangle (I) sector (J) triangle
</QUESTION>
<ANSWER>
(B)
</ANSWER>
</EXAMPLE>

<TASK>
Write a new prompt that will achieve an accuracy as high as possible and that is different from the old ones.
</TASK>

<RULES>
- It is very important that the new prompt is distinct from ALL the old ones!
- Ensure that you analyse the prompts with a high accuracy and reuse the patterns that worked in the past
- Ensure that you analyse the prompts with a low accuracy and avoid the patterns that didn't work in the past
- Think out loud before creating the prompt. Describe what has worked in the past and what hasn't. Only then create the new prompt.
- Use all available information like prompt length, formal/informal use of language, etc. for your analysis.
- Be creative, try out different ways of prompting the model. You may even come up with hypothetical scenarios that might improve the accuracy.
- You are generating system prompts. This means that there should be no placeholders in the prompt, as they cannot be filled at runtime. Instead focus on general instructions that will help the model to solve the task.
- Write your new prompt in double square brackets. Use only plain text for the prompt text and do not add any markdown (i.e. no hashtags, backticks, quotes, etc).
</RULES>
""",
    ),
    max_iterations: int = 3,
    concurrency: int = 10,
) -> dict[str, Union[str, float]]:
    if isinstance(csv_file, str) and os.path.isfile(csv_file):
        csv_file = await File.from_local(csv_file)

    df_train, df_test = await data_prep(csv_file)

    best_prompt, training_accuracy = await prompt_optimizer(
        df_train,
        target_model_config,
        review_model_config,
        optimizer_model_config,
        max_iterations,
        concurrency,
    )

    with flyte.group(name="test_data_evaluation"):
        baseline_test_accuracy = await evaluate_prompt(
            df_test,
            target_model_config,
            review_model_config,
            concurrency,
        )

        target_model_config.prompt = best_prompt
        test_accuracy = await evaluate_prompt(
            df_test,
            target_model_config,
            review_model_config,
            concurrency,
        )

    return {
        "best_prompt": best_prompt,
        "training_accuracy": training_accuracy,
        "baseline_test_accuracy": baseline_test_accuracy,
        "test_accuracy": test_accuracy,
    }

# {{/docs-fragment auto_prompt_engineering}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/auto_prompt_engineering/optimizer.py*
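`evaluate_prompt` enforces the `concurrency` input with a common asyncio pattern. Every row is scheduled at once with `asyncio.gather`, an `asyncio.Semaphore(concurrency)` allows only that many rows to call the models at the same time, and a lock guards the shared counters. The sketch below isolates that pattern. `evaluate_rows` and the async `judge` callable are hypothetical stand-ins for `evaluate_prompt` and `generate_and_review`, and the Flyte report logging is omitted.

```python
import asyncio
from typing import Awaitable, Callable


async def evaluate_rows(
    rows: list[tuple[str, str]],
    judge: Callable[[str, str], Awaitable[bool]],
    concurrency: int,
) -> float:
    semaphore = asyncio.Semaphore(concurrency)  # at most `concurrency` rows in flight
    lock = asyncio.Lock()  # mirrors the tutorial's counter_lock
    counter = {"correct": 0, "processed": 0}

    async def run_row(question: str, answer: str) -> None:
        async with semaphore:
            is_correct = await judge(question, answer)
            async with lock:
                counter["processed"] += 1
                counter["correct"] += int(is_correct)

    await asyncio.gather(*(run_row(q, a) for q, a in rows))
    return counter["correct"] / counter["processed"] if counter["processed"] else 0.0
```

asyncio runs coroutines on a single thread, and nothing is awaited between reading and updating the counters. The lock is therefore a safeguard rather than a strict requirement here. The semaphore is what actually limits concurrent model calls.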

## Run it

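A `__main__` block makes the file runnable as a script. It calls `flyte.init_from_config()` to initialize Flyte from your configuration, submits `auto_prompt_engineering` with `flyte.run`, prints the run URL, and waits for the run to finish: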

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "pandas==2.3.1",
#    "pyarrow==21.0.0",
#    "litellm==1.75.0",
# ]
# main = "auto_prompt_engineering"
# params = ""
# ///

# {{docs-fragment env}}
import asyncio
import html
import os
import re
from dataclasses import dataclass
from typing import Optional, Union

import flyte
import flyte.report
import pandas as pd
from flyte.io import File

env = flyte.TaskEnvironment(
    name="auto-prompt-engineering",
    image=flyte.Image.from_uv_script(
        __file__, name="auto-prompt-engineering", pre=True
    ),
    secrets=[flyte.Secret(key="openai_api_key", as_env_var="OPENAI_API_KEY")],
    resources=flyte.Resources(cpu=1),
)

CSS = """
<style>
    body {
        font-family: 'Segoe UI', Roboto, Arial, sans-serif;
    }
    .results-table {
        border-collapse: collapse;
        width: 100%;
        box-shadow: 0 2px 5px rgba(0,0,0,0.1);
        font-size: 14px;
    }
    .results-table th {
        background: linear-gradient(135deg, #4CAF50, #2E7D32);
        color: white;
        padding: 10px;
        text-align: left;
    }
    .results-table td {
        border: 1px solid #ddd;
        padding: 8px;
        vertical-align: top;
    }
    .results-table tr:nth-child(even) {background-color: #f9f9f9;}
    .results-table tr:hover {background-color: #f1f1f1;}
    .correct {color: #2E7D32; font-weight: bold;}
    .incorrect {color: #C62828; font-weight: bold;}
    .summary-card {
        background: #f9fbfd;
        padding: 14px 18px;
        border-radius: 8px;
        box-shadow: 0 1px 4px rgba(0,0,0,0.05);
        max-width: 800px;
        margin-top: 12px;
    }
    .summary-card h3 {
        margin-top: 0;
        color: #1e88e5;
        font-size: 16px;
    }
</style>
"""

# {{/docs-fragment env}}

# {{docs-fragment data_prep}}
@env.task
async def data_prep(csv_file: File | str) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Load Q&A data from a public Google Sheet CSV export URL and split into train/test DataFrames.
    The sheet should have columns: 'input' and 'target'.
    """
    df = pd.read_csv(
        await csv_file.download() if isinstance(csv_file, File) else csv_file
    )

    if "input" not in df.columns or "target" not in df.columns:
        raise ValueError("Sheet must contain 'input' and 'target' columns.")

    # Shuffle rows
    df = df.sample(frac=1, random_state=1234).reset_index(drop=True)

    # Train/Test split
    df_train = df.iloc[:150].rename(columns={"input": "question", "target": "answer"})
    df_test = df.iloc[150:250].rename(columns={"input": "question", "target": "answer"})

    return df_train, df_test

# {{/docs-fragment data_prep}}

# {{docs-fragment model_config}}
@dataclass
class ModelConfig:
    model_name: str
    hosted_model_uri: Optional[str] = None
    temperature: float = 0.0
    max_tokens: Optional[int] = 1000
    timeout: int = 600
    prompt: str = ""

# {{/docs-fragment model_config}}

# {{docs-fragment call_model}}
@flyte.trace
async def call_model(
    model_config: ModelConfig,
    messages: list[dict[str, str]],
) -> str:
    from litellm import acompletion

    response = await acompletion(
        model=model_config.model_name,
        api_base=model_config.hosted_model_uri,
        messages=messages,
        temperature=model_config.temperature,
        timeout=model_config.timeout,
        max_tokens=model_config.max_tokens,
    )
    return response.choices[0].message["content"]

# {{/docs-fragment call_model}}

# {{docs-fragment generate_and_review}}
async def generate_and_review(
    index: int,
    question: str,
    answer: str,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
) -> dict:
    # Generate response from target model
    response = await call_model(
        target_model_config,
        [
            {"role": "system", "content": target_model_config.prompt},
            {"role": "user", "content": question},
        ],
    )

    # Format review prompt with response + answer
    review_messages = [
        {
            "role": "system",
            "content": review_model_config.prompt.format(
                response=response,
                answer=answer,
            ),
        }
    ]
    verdict = await call_model(review_model_config, review_messages)

    # Normalize verdict
    verdict_clean = verdict.strip().lower()
    if verdict_clean not in {"true", "false"}:
        verdict_clean = "not sure"

    return {
        "index": index,
        "model_response": response,
        "is_correct": verdict_clean == "true",
    }

# {{/docs-fragment generate_and_review}}

async def run_grouped_task(
    i,
    index,
    question,
    answer,
    semaphore,
    target_model_config,
    review_model_config,
    counter,
    counter_lock,
):
    async with semaphore:
        with flyte.group(name=f"row-{i}"):
            result = await generate_and_review(
                index,
                question,
                answer,
                target_model_config,
                review_model_config,
            )

            async with counter_lock:
                # Update counters
                counter["processed"] += 1
                if result["is_correct"]:
                    counter["correct"] += 1
                    correct_html = "<span class='correct'>✔ Yes</span>"
                else:
                    correct_html = "<span class='incorrect'>✘ No</span>"

                # Calculate accuracy
                accuracy_pct = (counter["correct"] / counter["processed"]) * 100

            # Update chart
            await flyte.report.log.aio(
                f"<script>updateAccuracy({accuracy_pct});</script>",
                do_flush=True,
            )

            # Add row to table
            await flyte.report.log.aio(
                f"""
                <tr>
                    <td>{html.escape(question)}</td>
                    <td>{html.escape(answer)}</td>
                    <td>{html.escape(result['model_response'])}</td>
                    <td>{correct_html}</td>
                </tr>
                """,
                do_flush=True,
            )

            return result

# {{docs-fragment evaluate_prompt}}
@env.task(report=True)
async def evaluate_prompt(
    df: pd.DataFrame,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
    concurrency: int,
) -> float:
    semaphore = asyncio.Semaphore(concurrency)
    counter = {"correct": 0, "processed": 0}
    counter_lock = asyncio.Lock()

    # Write initial HTML structure
    await flyte.report.log.aio(
        CSS
        + """
        <script>
            function updateAccuracy(percent) {
                const bar = document.getElementById('acc-bar');
                const label = document.getElementById('acc-label');
                bar.setAttribute('width', percent * 3);
                label.textContent = `Accuracy: ${percent.toFixed(1)}%`;
            }
        </script>

        <h2 style="margin-top:0;">Model Evaluation Results</h2>
        <h3>Live Accuracy</h3>
        <svg width="320" height="30" id="accuracy-chart">
            <defs>
                <linearGradient id="acc-gradient" x1="0" x2="1" y1="0" y2="0">
                    <stop offset="0%" stop-color="#66bb6a"/>
                    <stop offset="100%" stop-color="#2e7d32"/>
                </linearGradient>
            </defs>
            <rect width="300" height="20" fill="#ddd" rx="5" ry="5"></rect>
            <rect id="acc-bar" width="0" height="20" fill="url(#acc-gradient)" rx="5" ry="5"></rect>
            <text id="acc-label" x="150" y="15" font-size="12" font-weight="bold" text-anchor="middle" fill="#000">
                Accuracy: 0.0%
            </text>
        </svg>

        <table class="results-table">
            <thead>
                <tr>
                    <th>Question</th>
                    <th>Answer</th>
                    <th>Model Response</th>
                    <th>Correct?</th>
                </tr>
            </thead>
            <tbody>
        """,
        do_flush=True,
    )

    # Launch tasks concurrently
    tasks = [
        run_grouped_task(
            i,
            row.Index,
            row.question,
            row.answer,
            semaphore,
            target_model_config,
            review_model_config,
            counter,
            counter_lock,
        )
        for i, row in enumerate(df.itertuples(index=True))
    ]
    await asyncio.gather(*tasks)

    # Close table
    await flyte.report.log.aio("</tbody></table>", do_flush=True)

    async with counter_lock:
        return (
            (counter["correct"] / counter["processed"]) if counter["processed"] else 0.0
        )

# {{/docs-fragment evaluate_prompt}}

@dataclass
class PromptResult:
    prompt: str
    accuracy: float

# {{docs-fragment prompt_optimizer}}
@env.task(report=True)
async def prompt_optimizer(
    df_train: pd.DataFrame,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
    optimizer_model_config: ModelConfig,
    max_iterations: int,
    concurrency: int,
) -> tuple[str, float]:
    prompt_accuracies: list[PromptResult] = []

    # Send styling + table header immediately
    await flyte.report.log.aio(
        CSS
        + """
    <h2 style="margin-bottom:6px;">📊 Prompt Accuracy Comparison</h2>
    <table class="results-table">
        <thead>
            <tr>
                <th>Prompt</th>
                <th>Accuracy</th>
            </tr>
        </thead>
    <tbody>
    """,
        do_flush=True,
    )

    # Step 1: Evaluate starting prompt and stream row
    with flyte.group(name="baseline_evaluation"):
        starting_accuracy = await evaluate_prompt(
            df_train,
            target_model_config,
            review_model_config,
            concurrency,
        )
        prompt_accuracies.append(
            PromptResult(prompt=target_model_config.prompt, accuracy=starting_accuracy)
        )

        await _log_prompt_row(target_model_config.prompt, starting_accuracy)

    # Step 2: Run up to max_iterations optimization steps, streaming after each.
    # A bounded loop means a response without [[...]] skips a step instead of retrying forever.
    for step in range(1, max_iterations + 1):
        with flyte.group(name=f"prompt_optimization_step_{step}"):
            # Prepare prompt scores string for optimizer
            prompt_scores_str = "\n".join(
                f"{result.prompt}: {result.accuracy:.2f}"
                for result in sorted(prompt_accuracies, key=lambda x: x.accuracy)
            )

            optimizer_model_prompt = optimizer_model_config.prompt.format(
                prompt_scores_str=prompt_scores_str
            )
            response = await call_model(
                optimizer_model_config,
                [{"role": "system", "content": optimizer_model_prompt}],
            )
            response = response.strip()

            match = re.search(r"\[\[(.*?)\]\]", response, re.DOTALL)
            if not match:
                print("No new prompt found. Skipping.")
                continue

            new_prompt = match.group(1)
            target_model_config.prompt = new_prompt
            accuracy = await evaluate_prompt(
                df_train,
                target_model_config,
                review_model_config,
                concurrency,
            )
            prompt_accuracies.append(PromptResult(prompt=new_prompt, accuracy=accuracy))

            # Log this new prompt row immediately
            await _log_prompt_row(new_prompt, accuracy)

    # Close table
    await flyte.report.log.aio("</tbody></table>", do_flush=True)

    # Find best
    best_result = max(prompt_accuracies, key=lambda x: x.accuracy)
    improvement = best_result.accuracy - starting_accuracy

    # Summary
    await flyte.report.log.aio(
        f"""
    <div class="summary-card">
        <h3>🏆 Summary</h3>
        <p><strong>Best Prompt:</strong> {html.escape(best_result.prompt)}</p>
        <p><strong>Best Accuracy:</strong> {best_result.accuracy*100:.2f}%</p>
        <p><strong>Improvement Over Baseline:</strong> {improvement*100:.2f}%</p>
    </div>
    """,
        do_flush=True,
    )

    return best_result.prompt, best_result.accuracy

# {{/docs-fragment prompt_optimizer}}

async def _log_prompt_row(prompt: str, accuracy: float):
    """Helper to log a single prompt/accuracy row to Flyte report."""
    pct = accuracy * 100
    if pct > 80:
        color = "linear-gradient(90deg, #4CAF50, #81C784)"
    elif pct > 60:
        color = "linear-gradient(90deg, #FFC107, #FFD54F)"
    else:
        color = "linear-gradient(90deg, #F44336, #E57373)"

    await flyte.report.log.aio(
        f"""
        <tr>
            <td>{html.escape(prompt)}</td>
            <td>
                {pct:.1f}%
                <div class="accuracy-bar-container">
                    <div class="accuracy-bar" style="width:{pct*1.6}px; background:{color};"></div>
                </div>
            </td>
        </tr>
        """,
        do_flush=True,
    )

# {{docs-fragment auto_prompt_engineering}}
@env.task
async def auto_prompt_engineering(
    csv_file: File | str = "https://dub.sh/geometric-shapes",
    target_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1-mini",
        hosted_model_uri=None,
        prompt="Solve the given problem about geometric shapes. Think step by step.",
        max_tokens=10000,
    ),
    review_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1-mini",
        hosted_model_uri=None,
        prompt="""You are a review model tasked with evaluating the correctness of a response to a geometric shapes problem.
The response may contain detailed steps and explanations, but the final answer is the key point.
Please determine if the final answer provided in the response is correct based on the ground truth answer.
Respond with 'True' if the final answer is correct and 'False' if it is not.
Only respond with 'True' or 'False', nothing else.

Model Response:
{response}

Ground Truth:
{answer}
""",
    ),
    optimizer_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1",
        hosted_model_uri=None,
        temperature=0.7,
        max_tokens=None,
        prompt="""
<EXPLANATION>
I have some prompts along with their corresponding accuracies.
The prompts are arranged in ascending order based on their accuracy, where higher accuracy indicates better quality.
</EXPLANATION>

<PROMPTS>
{prompt_scores_str}
</PROMPTS>

Each prompt was used together with a problem statement around geometric shapes.

<EXAMPLE>
<QUESTION>
This SVG path element <path d="M 55.57,80.69 L 57.38,65.80 M 57.38,65.80 L 48.90,57.46 M 48.90,57.46 L 45.58,47.78 M 45.58,47.78 L 53.25,36.07 L 66.29,48.90 L 78.69,61.09 L 55.57,80.69"/> draws a Options: (A) circle (B) heptagon (C) hexagon (D) kite (E) line (F) octagon (G) pentagon (H) rectangle (I) sector (J) triangle
</QUESTION>
<ANSWER>
(B)
</ANSWER>
</EXAMPLE>

<TASK>
Write a new prompt that will achieve an accuracy as high as possible and that is different from the old ones.
</TASK>

<RULES>
- It is very important that the new prompt is distinct from ALL the old ones!
- Ensure that you analyse the prompts with a high accuracy and reuse the patterns that worked in the past
- Ensure that you analyse the prompts with a low accuracy and avoid the patterns that didn't work in the past
- Think out loud before creating the prompt. Describe what has worked in the past and what hasn't. Only then create the new prompt.
- Use all available information like prompt length, formal/informal use of language, etc. for your analysis.
- Be creative, try out different ways of prompting the model. You may even come up with hypothetical scenarios that might improve the accuracy.
- You are generating system prompts. This means that there should be no placeholders in the prompt, as they cannot be filled at runtime. Instead focus on general instructions that will help the model to solve the task.
- Write your new prompt in double square brackets. Use only plain text for the prompt text and do not add any markdown (i.e. no hashtags, backticks, quotes, etc).
</RULES>
""",
    ),
    max_iterations: int = 3,
    concurrency: int = 10,
) -> dict[str, Union[str, float]]:
    if isinstance(csv_file, str) and os.path.isfile(csv_file):
        csv_file = await File.from_local(csv_file)

    df_train, df_test = await data_prep(csv_file)

    best_prompt, training_accuracy = await prompt_optimizer(
        df_train,
        target_model_config,
        review_model_config,
        optimizer_model_config,
        max_iterations,
        concurrency,
    )

    with flyte.group(name="test_data_evaluation"):
        baseline_test_accuracy = await evaluate_prompt(
            df_test,
            target_model_config,
            review_model_config,
            concurrency,
        )

        target_model_config.prompt = best_prompt
        test_accuracy = await evaluate_prompt(
            df_test,
            target_model_config,
            review_model_config,
            concurrency,
        )

    return {
        "best_prompt": best_prompt,
        "training_accuracy": training_accuracy,
        "baseline_test_accuracy": baseline_test_accuracy,
        "test_accuracy": test_accuracy,
    }

# {{/docs-fragment auto_prompt_engineering}}

# {{docs-fragment main}}
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(auto_prompt_engineering)
    print(run.url)
    run.wait()
# {{/docs-fragment main}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/auto_prompt_engineering/optimizer.py*

Run it with:

```
uv run optimizer.py
```

![Execution](https://raw.githubusercontent.com/unionai/unionai-docs-static/main/gifs/tutorials/prompt_engineering/execution.gif)

## Why this matters

Most prompt engineering pipelines start as quick scripts or notebooks. They're fine for experimenting, but they're difficult to scale, reproduce, or debug when things go wrong.

With Flyte 2, we get a more reliable setup:

- Run many evaluations in parallel with [async Python](https://www.union.ai/docs/v2/union/user-guide/migration/flyte-2/async/page.md) or [native DSL](https://www.union.ai/docs/v2/union/user-guide/migration/flyte-2/async/page.md).
- Watch accuracy improve in real time and link results back to the exact dataset, prompt, and model config used.
- Resume cleanly after failures without rerunning everything from scratch.
- Reuse the same pattern to tune other parameters like temperature, retrieval depth, or agent strategies, not just prompts.
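Parallel evaluation comes down to launching one coroutine per example while capping how many run at once, which appears to be the role of the `concurrency` argument to `evaluate_prompt`. The sketch below assumes an `asyncio.Semaphore` does the capping and replaces the target and review model calls with a hypothetical `fake_grade` coroutine; the tutorial's `evaluate_prompt` is only partly shown here and may be structured differently.

```python
import asyncio


async def fake_grade(example: int, state: dict) -> bool:
    # Hypothetical stand-in for "call the target model, then ask the review model".
    state["in_flight"] += 1
    state["peak"] = max(state["peak"], state["in_flight"])
    await asyncio.sleep(0.01)  # simulate network latency
    state["in_flight"] -= 1
    return example % 3 != 0  # deterministic fake verdict


async def evaluate(examples: list[int], concurrency: int) -> tuple[float, int]:
    semaphore = asyncio.Semaphore(concurrency)
    state = {"in_flight": 0, "peak": 0}

    async def run_one(example: int) -> bool:
        # At most `concurrency` coroutines hold the semaphore at the same time.
        async with semaphore:
            return await fake_grade(example, state)

    verdicts = await asyncio.gather(*(run_one(e) for e in examples))
    accuracy = sum(verdicts) / len(verdicts) if verdicts else 0.0
    return accuracy, state["peak"]


accuracy, peak = asyncio.run(evaluate(list(range(30)), concurrency=10))
print(f"accuracy={accuracy:.2f}, peak in-flight calls={peak}")
```

`asyncio.gather` returns results in the order the awaitables were passed, so verdicts can be matched back to their rows if needed.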

## Next steps

You now have a working automated prompt engineering pipeline. Here’s how you can take it further:

- **Optimize beyond prompts**: Tune temperature, retrieval strategies, or tool usage just like prompts.
- **Expand evaluation metrics**: Add latency, cost, robustness, or diversity alongside accuracy.
- **Move toward agentic evaluation**: Instead of single prompts, test how agents plan, use tools, and recover from failures in long-horizon tasks.
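The first two next steps can reuse the loop structure of `prompt_optimizer`: propose a candidate, evaluate it, keep the history, and return the best. The sketch below swaps the prompt for a sampling temperature and scores each candidate on accuracy minus a latency penalty. `measure` is a hypothetical, deterministic stand-in for a real evaluation run, and the 0.1-per-second penalty weight is an arbitrary assumption you would tune for your own cost model.

```python
from dataclasses import dataclass


@dataclass
class Trial:
    temperature: float
    accuracy: float
    latency_s: float

    def score(self, latency_weight: float = 0.1) -> float:
        # Trade accuracy off against latency; the weight is an assumption, not a recommendation.
        return self.accuracy - latency_weight * self.latency_s


def measure(temperature: float) -> Trial:
    # Hypothetical evaluation: accuracy peaks near 0.4 and latency grows with temperature.
    accuracy = 0.8 - (temperature - 0.4) ** 2
    latency_s = 1.0 + temperature
    return Trial(temperature, accuracy, latency_s)


def search(candidates: list[float]) -> Trial:
    history = [measure(t) for t in candidates]
    for trial in history:
        print(
            f"T={trial.temperature:.1f} acc={trial.accuracy:.3f} "
            f"latency={trial.latency_s:.1f}s score={trial.score():.3f}"
        )
    return max(history, key=lambda trial: trial.score())


best = search([0.0, 0.2, 0.4, 0.7, 1.0])
print("best temperature:", best.temperature)
```

In a real pipeline, `measure` would call something like `evaluate_prompt` with the candidate setting, and each trial could run in its own `flyte.group` just as the prompt iterations do.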

With this pipeline in place, prompt engineering becomes repeatable, observable, and scalable enough to support production-grade LLM and agent systems.

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/model-training ===

# Model training

Tutorials for training, fine-tuning, and hyperparameter optimization of models at scale.

### **Model training > Hyperparameter optimization**

Run large-scale HPO experiments with zero manual tracking, deterministic results, and automatic recovery.

### **Model training > LLM fine-tuning with LoRA and QLoRA**

Fine-tune a language model for SQL generation using full, LoRA, or QLoRA methods in one Flyte pipeline.

### **Model training > BERT emotion classification**

Fine-tune ModernBERT on Twitter emotion labels with confusion-matrix evaluation and attention visualizations.

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/model-training/llm-fine-tuning-lora-qlora ===

# LLM fine-tuning with LoRA and QLoRA

> [!NOTE]
> The full code is available in the [unionai-examples repository](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/llm_fine_tuning_lora_qlora).

This tutorial fine-tunes a language model for SQL generation using three methods in one workflow: **full** fine-tuning, **LoRA** adapters, and **QLoRA** (4-bit quantized base + LoRA). The pipeline prepares an instruction dataset from HuggingFace, trains with [TRL](https://huggingface.co/docs/trl) `SFTTrainer`, evaluates against a base-model baseline, and streams training charts into Flyte reports.
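LoRA freezes a pretrained weight matrix `W` (shape `d_out x d_in`) and trains two small matrices, `B` (`d_out x r`) and `A` (`r x d_in`), so the effective weight is `W + (alpha / r) * B @ A`. This is the same `alpha / r` scale noted in the `LoraConfig` comments further down. The pure-Python sketch below uses no PyTorch or PEFT, and its helper names are hypothetical. It shows two consequences that matter in this pipeline: adapters add only `r * (d_in + d_out)` trainable parameters per matrix, and merging them back into `W`, which is what `merge_and_unload()` does after training, yields an ordinary dense matrix.

```python
def matmul(x: list[list[float]], y: list[list[float]]) -> list[list[float]]:
    # Plain nested-list matrix product, enough for a toy example.
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*y)] for row in x]


def lora_merge(w, b, a, alpha: float, r: int):
    # Effective weight W + (alpha / r) * B @ A.
    delta = matmul(b, a)
    scale = alpha / r
    return [
        [w_ij + scale * d_ij for w_ij, d_ij in zip(w_row, d_row)]
        for w_row, d_row in zip(w, delta)
    ]


def lora_params(d_out: int, d_in: int, r: int) -> tuple[int, int]:
    # (trainable adapter params, frozen params in the original matrix)
    return r * (d_in + d_out), d_out * d_in


# Toy 2x3 weight with a rank-1 adapter (r=1, alpha=2, so the scale is 2.0).
w = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
b = [[0.5], [0.0]]     # d_out x r
a = [[1.0, 2.0, 3.0]]  # r x d_in
print(lora_merge(w, b, a, alpha=2, r=1))  # [[2.0, 2.0, 3.0], [0.0, 1.0, 0.0]]

# A hypothetical 4096 x 4096 projection with the tutorial's defaults r=16, alpha=32.
added, frozen = lora_params(4096, 4096, r=16)
print(f"{added:,} trainable vs {frozen:,} frozen ({added / frozen:.2%})")
```

PEFT initializes `B` to zeros by default, so an untrained adapter merges to exactly `W` and leaves the base model's behavior unchanged.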

Flyte provides:

- **GPU training** with live loss and learning-rate charts via `report=True`.
- **Method switching** through a single `method` parameter (`full`, `lora`, or `qlora`).
- **Cached dataset preparation** for fast iteration on hyperparameters.

## Define the task environments

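Both task environments share one image built from the script's inline `uv` dependency header. The GPU environment requests one GPU and declares a HuggingFace token secret, exposed to the task as the `HF_TOKEN` environment variable, so gated models can be downloaded. The CPU environment runs the cached dataset preparation and lists `gpu_env` in `depends_on`.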

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.4.0",
#    "torch>=2.1.0",
#    "transformers>=4.45.0",
#    "peft>=0.13.0",
#    "trl>=0.12.0",
#    "datasets>=3.0.0",
#    "bitsandbytes>=0.44.0",
#    "accelerate>=0.34.0",
# ]
# main = "pipeline"
# params = ""
# ///
import asyncio
import json
import logging
import os
import tempfile

import flyte
import flyte.io
import flyte.report

# {{docs-fragment env}}
import os

main_img = flyte.Image.from_uv_script(__file__, name="llm-fine-tuning-lora-qlora", pre=True)

gpu_env = flyte.TaskEnvironment(
    name="llm-fine-tuning-lora-qlora-gpu",
    image=main_img,
    resources=flyte.Resources(cpu=4, memory="24Gi", gpu=1),
    secrets=[flyte.Secret(key="huggingface-token", as_env_var="HF_TOKEN")],
)

cpu_env = flyte.TaskEnvironment(
    name="llm-fine-tuning-lora-qlora-cpu",
    image=main_img,
    resources=flyte.Resources(cpu=2, memory="8Gi"),
    depends_on=[gpu_env],
)

HF_TOKEN = os.environ.get("HF_TOKEN")
# {{/docs-fragment env}}

from report_helpers import make_bar_chart, make_line_chart, pipeline_step_indicator, wrap_report

logging.basicConfig(level=logging.WARNING, format="%(message)s", force=True)
log = logging.getLogger(__name__)
log.setLevel(logging.INFO)

# ------------------------------------------------------------------
# Task 1: Prepare dataset
# ------------------------------------------------------------------

@cpu_env.task(cache="auto")
async def prepare_data(
    dataset_name: str = "b-mc2/sql-create-context",
    max_train_samples: int = 5000,
    max_eval_samples: int = 500,
) -> flyte.io.Dir:
    """Download dataset from HuggingFace and format for instruction fine-tuning."""
    from datasets import DatasetDict, load_dataset

    log.info(f"Loading dataset: {dataset_name}")
    ds = load_dataset(dataset_name, split="train")

    def format_example(ex):
        return {
            "text": (
                "### Task: Generate a SQL query to answer the question.\n"
                f"### Schema:\n{ex['context']}\n"
                f"### Question:\n{ex['question']}\n"
                f"### SQL:\n{ex['answer']}\n<|endoftext|>"
            )
        }

    ds = ds.map(format_example)

    # Split into train and eval
    total = len(ds)
    train_end = min(max_train_samples, total - max_eval_samples)
    eval_start = train_end
    eval_end = min(eval_start + max_eval_samples, total)

    processed = DatasetDict({
        "train": ds.select(range(train_end)),
        "eval": ds.select(range(eval_start, eval_end)),
    })

    output_dir = os.path.join(tempfile.mkdtemp(), "dataset")
    processed.save_to_disk(output_dir)
    log.info(f"Dataset ready: {len(processed['train'])} train, {len(processed['eval'])} eval")

    return await flyte.io.Dir.from_local(output_dir)

# ------------------------------------------------------------------
# Task 2: Train
# ------------------------------------------------------------------

@gpu_env.task(report=True)
async def train(
    model_name: str,
    data_dir: flyte.io.Dir,
    method: str = "lora",
    epochs: int = 3,
    lr: float = 2e-4,
    batch_size: int = 4,
    lora_r: int = 16,
    lora_alpha: int = 32,
) -> flyte.io.Dir:
    """Fine-tune a model using full, LoRA, or QLoRA method."""
    import torch
    from datasets import load_from_disk
    from transformers import AutoModelForCausalLM, AutoTokenizer, TrainerCallback
    from trl import SFTConfig, SFTTrainer

    log.info(f"Training: model={model_name}, method={method}")

    # -- Load data --
    data_path = await data_dir.download()
    dataset = load_from_disk(data_path)

    # -- Load tokenizer --
    token_kwargs = {"token": HF_TOKEN} if HF_TOKEN else {}
    tokenizer = AutoTokenizer.from_pretrained(model_name, **token_kwargs)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # -- Initial report: loading model --
    await flyte.report.replace.aio(
        wrap_report(
            f"<h2>Loading Model...</h2>"
            f"<h3>{model_name}</h3>"
            f'<div class="card">'
            f"<p><b>Method:</b> <span class=\"badge badge-info\">{method.upper()}</span></p>"
            f"<p><b>Dataset:</b> {len(dataset['train']):,} train / {len(dataset['eval']):,} eval</p>"
            f"</div>"
        ),
        do_flush=True,
    )

    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    dtype = torch.bfloat16 if use_bf16 else torch.float32

    if method == "qlora":
        from transformers import BitsAndBytesConfig

        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            **token_kwargs,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=dtype,
                bnb_4bit_use_double_quant=True,
            ),
            dtype=dtype,
            device_map="auto",
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            **token_kwargs,
            dtype=dtype,
            device_map="auto",
        )

    # -- Apply LoRA adapters --
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())

    if method in ("lora", "qlora"):
        from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

        if method == "qlora":
            model = prepare_model_for_kbit_training(model)

        lora_config = LoraConfig(
            r=lora_r,              # Rank — size of the low-rank matrices. Higher = more capacity but more params
            lora_alpha=lora_alpha,  # Scaling factor — controls adapter impact. Effective scale = alpha/r
            # Attention layers — LoRA adapters inject low-rank updates here:
            #   q_proj (Query)     — what to look for in context
            #   k_proj (Key)       — what each token offers to match against
            #   v_proj (Value)     — what information to extract once matched
            #   o_proj (Output)    — combines multi-head attention results
            # MLP layers — LoRA adapters also update the feed-forward network:
            #   gate_proj (Gate)   — controls how much information flows through (SwiGLU activation)
            #   up_proj (Up)       — projects to a higher dimension for richer representations
            #   down_proj (Down)   — projects back down to the model's hidden size
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            lora_dropout=0.05,     # Dropout on adapter weights — light regularization to prevent overfitting
            bias="none",           # Don't train bias terms — keeps adapter small and stable
            task_type="CAUSAL_LM", # Tells PEFT this is a text generation model (vs classification, etc.)
        )
        model = get_peft_model(model, lora_config)
        trainable_params, total_params = model.get_nb_trainable_parameters()
        log.info(f"Trainable params: {trainable_params:,} / {total_params:,} ({trainable_params / total_params * 100:.1f}%)")

    # -- Live training report state --
    training_log: list[dict] = []
    loop = asyncio.get_running_loop()

    method_badge = f'<span class="badge badge-info">{method.upper()}</span>'
    if method == "qlora":
        method_badge = f'<span class="badge badge-success">QLoRA (4-bit)</span>'
    elif method == "full":
        method_badge = f'<span class="badge badge-danger">Full Fine-Tune</span>'

    def _build_training_report(max_steps: int) -> str:
        """Build the live training report HTML from current training_log."""
        stats_html = f"""
        <h2>Training in Progress...</h2>
        <h3>{model_name}</h3>
        <div class="stat-grid">
          <div class="stat"><div class="value">{method.upper()}</div><div class="label">Method</div></div>
          <div class="stat"><div class="value">{len(dataset['train']):,}</div><div class="label">Train Examples</div></div>
          <div class="stat"><div class="value">{epochs}</div><div class="label">Epochs</div></div>
          <div class="stat"><div class="value">{lr}</div><div class="label">Learning Rate</div></div>
          <div class="stat"><div class="value">{batch_size}</div><div class="label">Batch Size</div></div>
          <div class="stat"><div class="value">{trainable_params / total_params * 100:.1f}%</div><div class="label">Trainable</div></div>
        </div>
        <p>Method: {method_badge} | Total params: {total_params:,} | Trainable: {trainable_params:,}</p>
        """

        charts_html = ""
        if training_log:
            current = training_log[-1]
            progress_pct = current["step"] / max_steps * 100 if max_steps else 0
            charts_html += f"""
            <div class="card">
              <b>Step {current['step']}/{max_steps}</b>
              ({progress_pct:.0f}%) |
              Epoch {current['epoch']:.2f}/{epochs} |
              Loss: <span class="highlight">{current['loss']:.4f}</span>
              <div style="background:#e9ecef;border-radius:4px;height:8px;margin-top:8px;">
                <div style="background:#0f3460;width:{progress_pct:.1f}%;height:100%;border-radius:4px;"></div>
              </div>
            </div>
            """

            loss_chart = make_line_chart(
                data=training_log,
                x_key="epoch",
                y_keys=["loss"],
                title="Training Loss",
                x_label="Epoch",
                y_label="Loss",
                colors=["#5a7db5"],
            )
            charts_html += f'<div class="chart-container">{loss_chart}</div>'

            if "lr" in training_log[0]:
                lr_chart = make_line_chart(
                    data=training_log,
                    x_key="epoch",
                    y_keys=["lr"],
                    title="Learning Rate Schedule",
                    x_label="Epoch",
                    y_label="LR",
                    colors=["#0f3460"],
                )
                charts_html += f'<div class="chart-container">{lr_chart}</div>'

            if "grad_norm" in training_log[0]:
                grad_chart = make_line_chart(
                    data=training_log,
                    x_key="epoch",
                    y_keys=["grad_norm"],
                    title="Gradient Norm",
                    x_label="Epoch",
                    y_label="Grad Norm",
                    colors=["#06d6a0"],
                )
                charts_html += f'<div class="chart-container">{grad_chart}</div>'

        return wrap_report(stats_html + charts_html)

    # -- Metrics callback with live report updates --
    class MetricsCallback(TrainerCallback):
        def on_log(self, args, state, control, logs=None, **kwargs):
            if not logs or "loss" not in logs:
                return
            entry = {
                "step": state.global_step,
                "epoch": round(logs.get("epoch", 0), 2),
                "loss": round(logs["loss"], 4),
            }
            if "learning_rate" in logs:
                entry["lr"] = logs["learning_rate"]
            if "grad_norm" in logs:
                entry["grad_norm"] = round(float(logs["grad_norm"]), 4)
            training_log.append(entry)
            log.info(
                f"step={state.global_step}/{state.max_steps} "
                f"epoch={entry['epoch']:.2f} "
                f"loss={entry['loss']:.4f}"
            )

            asyncio.run_coroutine_threadsafe(
                flyte.report.replace.aio(
                    _build_training_report(state.max_steps),
                    do_flush=True,
                ),
                loop,
            )

    # -- Train --
    output_dir = os.path.join(tempfile.mkdtemp(), "checkpoints")
    training_args = SFTConfig(
        output_dir=output_dir,
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        learning_rate=lr,
        logging_steps=10,
        save_strategy="epoch",
        bf16=use_bf16,
        fp16=not use_bf16 and torch.cuda.is_available(),
        gradient_accumulation_steps=4,
        warmup_steps=10,
        report_to="none",
    )

    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["eval"],
        processing_class=tokenizer,
        callbacks=[MetricsCallback()],
    )

    log.info("Starting training...")
    await asyncio.to_thread(trainer.train)
    log.info("Training complete.")

    # -- Merge LoRA weights and save --
    save_dir = os.path.join(tempfile.mkdtemp(), "finetuned_model")

    if method in ("lora", "qlora"):
        log.info("Merging LoRA weights into base model...")
        model = model.merge_and_unload()

    model.save_pretrained(save_dir)
    tokenizer.save_pretrained(save_dir)
    log.info(f"Model saved to {save_dir}")

    # -- Final training report --
    final_loss = training_log[-1]["loss"] if training_log else "N/A"

    loss_chart = make_line_chart(
        data=training_log,
        x_key="epoch",
        y_keys=["loss"],
        title="Training Loss",
        x_label="Epoch",
        y_label="Loss",
        colors=["#5a7db5"],
    ) if training_log else ""

    lr_chart = ""
    if training_log and "lr" in training_log[0]:
        lr_chart = make_line_chart(
            data=training_log,
            x_key="epoch",
            y_keys=["lr"],
            title="Learning Rate Schedule",
            x_label="Epoch",
            y_label="LR",
            colors=["#0f3460"],
        )

    await flyte.report.replace.aio(
        wrap_report(
            f"<h2>Training Complete</h2>"
            f"<h3>{model_name}</h3>"
            f'<div class="stat-grid">'
            f'  <div class="stat"><div class="value">{method.upper()}</div><div class="label">Method</div></div>'
            f'  <div class="stat"><div class="value">{final_loss}</div><div class="label">Final Loss</div></div>'
            f'  <div class="stat"><div class="value">{epochs}</div><div class="label">Epochs</div></div>'
            f'  <div class="stat"><div class="value">{total_params:,}</div><div class="label">Total Params</div></div>'
            f'  <div class="stat"><div class="value">{trainable_params:,}</div><div class="label">Trainable Params</div></div>'
            f'  <div class="stat"><div class="value">{trainable_params / total_params * 100:.1f}%</div><div class="label">% Trainable</div></div>'
            f'</div>'
            f'<div class="chart-container">{loss_chart}</div>'
            f'{f"""<div class="chart-container">{lr_chart}</div>""" if lr_chart else ""}'
        ),
        do_flush=True,
    )

    return await flyte.io.Dir.from_local(save_dir)

# ------------------------------------------------------------------
# Task 3: Evaluate — before/after comparison
# ------------------------------------------------------------------

@gpu_env.task(report=True)
async def evaluate(
    model_name: str,
    finetuned_dir: flyte.io.Dir,
    data_dir: flyte.io.Dir,
    num_examples: int = 50,
) -> str:
    """Compare base model vs fine-tuned model on test examples."""
    import torch
    from datasets import load_from_disk
    from transformers import AutoModelForCausalLM, AutoTokenizer

    log.info("Starting evaluation...")
    await flyte.report.replace.aio(
        wrap_report(
            "<h2>Evaluation</h2>"
            '<div class="card"><p>Loading models and running inference...</p></div>'
        ),
        do_flush=True,
    )

    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    dtype = torch.bfloat16 if use_bf16 else torch.float32

    # Load eval data
    data_path = await data_dir.download()
    dataset = load_from_disk(data_path)
    eval_ds = dataset["eval"].select(range(min(num_examples, len(dataset["eval"]))))

    # Load tokenizer
    token_kwargs = {"token": HF_TOKEN} if HF_TOKEN else {}
    tokenizer = AutoTokenizer.from_pretrained(model_name, **token_kwargs)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    def generate_sql(model, prompt, max_new_tokens=128):
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id,
            )
        return tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True).strip()

    def normalize_sql(sql):
        """Extract the first SQL statement and normalize for comparison."""
        # Truncate at first ### or newline to isolate the SQL
        for stop in ["###", "\n"]:
            if stop in sql:
                sql = sql[:sql.index(stop)]
        # Strip whitespace after removing a trailing ";" so "select a ;" still matches "select a".
        return " ".join(sql.lower().split()).rstrip(";").strip()

    def build_prompt(example):
        return (
            "### Task: Generate a SQL query to answer the question.\n"
            f"### Schema:\n{example['context']}\n"
            f"### Question:\n{example['question']}\n"
            "### SQL:\n"
        )

    # -- Run base model --
    log.info(f"Loading base model: {model_name}")
    await flyte.report.replace.aio(
        wrap_report(
            f"<h2>Evaluation</h2>"
            f'<div class="stat-grid">'
            f'  <div class="stat"><div class="value">{len(eval_ds)}</div><div class="label">Eval Examples</div></div>'
            f'  <div class="stat"><div class="value">1/2</div><div class="label">Phase</div></div>'
            f'</div>'
            f'<div class="card"><p>Running <b>base model</b> inference...</p>'
            f'<div style="background:#e9ecef;border-radius:4px;height:8px;margin-top:8px;">'
            f'<div style="background:#adb5bd;width:25%;height:100%;border-radius:4px;"></div>'
            f'</div></div>'
        ),
        do_flush=True,
    )

    base_model = AutoModelForCausalLM.from_pretrained(
        model_name, **token_kwargs, dtype=dtype, device_map="auto",
    )

    base_results = []
    for i, example in enumerate(eval_ds):
        prompt = build_prompt(example)
        generated = generate_sql(base_model, prompt)
        base_results.append(generated)
        if (i + 1) % 10 == 0:
            log.info(f"Base model: {i + 1}/{len(eval_ds)}")
            pct = (i + 1) / len(eval_ds) * 50
            await flyte.report.replace.aio(
                wrap_report(
                    f"<h2>Evaluation</h2>"
                    f'<div class="card"><p>Running <b>base model</b> inference... {i + 1}/{len(eval_ds)}</p>'
                    f'<div style="background:#e9ecef;border-radius:4px;height:8px;margin-top:8px;">'
                    f'<div style="background:#adb5bd;width:{pct:.0f}%;height:100%;border-radius:4px;"></div>'
                    f'</div></div>'
                ),
                do_flush=True,
            )

    del base_model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # -- Run fine-tuned model --
    log.info("Loading fine-tuned model...")
    await flyte.report.replace.aio(
        wrap_report(
            f"<h2>Evaluation</h2>"
            f'<div class="card"><p>Running <b>fine-tuned model</b> inference...</p>'
            f'<div style="background:#e9ecef;border-radius:4px;height:8px;margin-top:8px;">'
            f'<div style="background:#0f3460;width:50%;height:100%;border-radius:4px;"></div>'
            f'</div></div>'
        ),
        do_flush=True,
    )

    ft_path = await finetuned_dir.download()
    ft_model = AutoModelForCausalLM.from_pretrained(
        ft_path, dtype=dtype, device_map="auto",
    )

    ft_results = []
    for i, example in enumerate(eval_ds):
        prompt = build_prompt(example)
        generated = generate_sql(ft_model, prompt)
        ft_results.append(generated)
        if (i + 1) % 10 == 0:
            log.info(f"Fine-tuned model: {i + 1}/{len(eval_ds)}")
            pct = 50 + (i + 1) / len(eval_ds) * 50
            await flyte.report.replace.aio(
                wrap_report(
                    f"<h2>Evaluation</h2>"
                    f'<div class="card"><p>Running <b>fine-tuned model</b> inference... {i + 1}/{len(eval_ds)}</p>'
                    f'<div style="background:#e9ecef;border-radius:4px;height:8px;margin-top:8px;">'
                    f'<div style="background:#0f3460;width:{pct:.0f}%;height:100%;border-radius:4px;"></div>'
                    f'</div></div>'
                ),
                do_flush=True,
            )

    del ft_model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # -- Score --
    base_correct = 0
    ft_correct = 0
    comparisons = []

    for i, example in enumerate(eval_ds):
        expected = example["answer"]
        base_gen = base_results[i]
        ft_gen = ft_results[i]

        base_match = normalize_sql(base_gen) == normalize_sql(expected)
        ft_match = normalize_sql(ft_gen) == normalize_sql(expected)

        if base_match:
            base_correct += 1
        if ft_match:
            ft_correct += 1

        comparisons.append({
            "question": example["question"],
            "schema": example["context"],
            "expected": expected,
            "base": base_gen,
            "finetuned": ft_gen,
            "base_correct": base_match,
            "ft_correct": ft_match,
        })

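    total = len(eval_ds)
    # An empty eval split (num_examples=0) would otherwise raise ZeroDivisionError.
    base_acc = base_correct / total * 100 if total else 0.0
    ft_acc = ft_correct / total * 100 if total else 0.0
    improvement = ft_acc - base_acc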

    log.info(f"Base model accuracy: {base_acc:.1f}% ({base_correct}/{total})")
    log.info(f"Fine-tuned accuracy: {ft_acc:.1f}% ({ft_correct}/{total})")

    # -- Build final eval report --
    improvement_badge = (
        f'<span class="badge badge-success">+{improvement:.1f}pp</span>'
        if improvement > 0
        else f'<span class="badge badge-danger">{improvement:.1f}pp</span>'
    )

    bar_chart = make_bar_chart(
        labels=["Exact Match Accuracy"],
        series={
            "Base Model": [base_acc],
            "Fine-Tuned": [ft_acc],
        },
        title="Base vs Fine-Tuned Accuracy",
        colors=["#adb5bd", "#0f3460"],
        y_max_cap=100.0,
    )

    # Escape text before embedding it in HTML: SQL such as "age > 30" contains <, >, &.
    from html import escape

    examples_html = ""
    for c in comparisons[:10]:
        base_badge = '<span class="badge badge-success">correct</span>' if c["base_correct"] else '<span class="badge badge-danger">wrong</span>'
        ft_badge = '<span class="badge badge-success">correct</span>' if c["ft_correct"] else '<span class="badge badge-danger">wrong</span>'
        examples_html += f"""
        <div class="card">
          <p><b>Q:</b> {escape(c['question'])}</p>
          <p style="font-size:0.85em; color:#6c757d;"><b>Schema:</b> {escape(c['schema'][:200])}...</p>
          <table>
            <tr><th>Source</th><th>SQL</th><th>Result</th></tr>
            <tr><td>Expected</td><td><code>{escape(c['expected'])}</code></td><td></td></tr>
            <tr><td>Base</td><td><code>{escape(c['base'][:200])}</code></td><td>{base_badge}</td></tr>
            <tr><td>Fine-tuned</td><td><code>{escape(c['finetuned'][:200])}</code></td><td>{ft_badge}</td></tr>
          </table>
        </div>"""

    await flyte.report.replace.aio(
        wrap_report(
            f"<h2>Evaluation Results</h2>"
            f'<div class="stat-grid">'
            f'  <div class="stat"><div class="value">{base_acc:.1f}%</div><div class="label">Base Accuracy</div></div>'
            f'  <div class="stat"><div class="value">{ft_acc:.1f}%</div><div class="label">Fine-Tuned Accuracy</div></div>'
            f'  <div class="stat"><div class="value">{improvement:+.1f}pp</div><div class="label">Improvement</div></div>'
            f'  <div class="stat"><div class="value">{total}</div><div class="label">Eval Examples</div></div>'
            f'</div>'
            f'<div class="chart-container">{bar_chart}</div>'
            f'<h3>Example Comparisons {improvement_badge}</h3>'
            f'{examples_html}'
            f'<div class="note">'
            f'<b>Note:</b> Exact match accuracy compares normalized SQL output. '
            f'The fine-tuned model may generate semantically correct queries that differ in formatting.'
            f'</div>'
        ),
        do_flush=True,
    )

    return json.dumps({
        "base_accuracy": round(base_acc, 1),
        "finetuned_accuracy": round(ft_acc, 1),
        "improvement": round(ft_acc - base_acc, 1),
        "num_examples": total,
        "comparisons": comparisons[:10],
    })

# ------------------------------------------------------------------
# Pipeline: orchestrate everything
# ------------------------------------------------------------------

# {{docs-fragment pipeline}}
@cpu_env.task(report=True)
async def pipeline(
    model_name: str = "HuggingFaceTB/SmolLM2-135M",
    dataset_name: str = "b-mc2/sql-create-context",
    method: str = "lora",
    epochs: int = 3,
    lr: float = 2e-4,
    batch_size: int = 4,
    max_train_samples: int = 5000,
    max_eval_samples: int = 500,
    num_eval_examples: int = 50,
    lora_r: int = 16,
    lora_alpha: int = 32,
) -> flyte.io.Dir:
    """
    End-to-end LLM fine-tuning pipeline.

    1. Download and format dataset
    2. Fine-tune model (full / LoRA / QLoRA)
    3. Evaluate: before/after comparison on test set

    Returns the fine-tuned model directory so it can be served directly.
    """
    log.info(f"Pipeline: {model_name} | method={method} | dataset={dataset_name}")
    steps = ["Prepare Data", "Train", "Evaluate"]

    method_badge = f'<span class="badge badge-info">{method.upper()}</span>'

    # Step 1: Prepare data
    await flyte.report.replace.aio(
        wrap_report(
            f"<h2>LLM Fine-Tuning Pipeline</h2>"
            f"<h3>{model_name} {method_badge}</h3>"
            f'{pipeline_step_indicator(0, steps)}'
            f'<div class="card"><p>Downloading and formatting dataset: <b>{dataset_name}</b>...</p></div>'
        ),
        do_flush=True,
    )

    data_dir = await prepare_data(dataset_name, max_train_samples, max_eval_samples)

    # Step 2: Train
    await flyte.report.replace.aio(
        wrap_report(
            f"<h2>LLM Fine-Tuning Pipeline</h2>"
            f"<h3>{model_name} {method_badge}</h3>"
            f'{pipeline_step_indicator(1, steps)}'
            f'<div class="card"><p>Training in progress... check the <b>train</b> task report for live charts.</p></div>'
        ),
        do_flush=True,
    )

    finetuned_dir = await train(
        model_name, data_dir, method, epochs, lr, batch_size, lora_r, lora_alpha,
    )

    # Step 3: Evaluate
    await flyte.report.replace.aio(
        wrap_report(
            f"<h2>LLM Fine-Tuning Pipeline</h2>"
            f"<h3>{model_name} {method_badge}</h3>"
            f'{pipeline_step_indicator(2, steps)}'
            f'<div class="card"><p>Evaluating base vs fine-tuned model...</p></div>'
        ),
        do_flush=True,
    )

    result = await evaluate(model_name, finetuned_dir, data_dir, num_eval_examples)
    metrics = json.loads(result)

    # Final pipeline report
    improvement = metrics["improvement"]
    improvement_badge = (
        f'<span class="badge badge-success">+{improvement:.1f}pp</span>'
        if improvement > 0
        else f'<span class="badge badge-danger">{improvement:.1f}pp</span>'
    )

    await flyte.report.replace.aio(
        wrap_report(
            f"<h2>Pipeline Complete</h2>"
            f"<h3>{model_name} {method_badge}</h3>"
            f'{pipeline_step_indicator(3, steps)}'
            f'<div class="stat-grid">'
            f'  <div class="stat"><div class="value">{metrics["base_accuracy"]}%</div><div class="label">Base Accuracy</div></div>'
            f'  <div class="stat"><div class="value">{metrics["finetuned_accuracy"]}%</div><div class="label">Fine-Tuned Accuracy</div></div>'
            f'  <div class="stat"><div class="value">{improvement:+.1f}pp</div><div class="label">Improvement {improvement_badge}</div></div>'
            f'  <div class="stat"><div class="value">{method.upper()}</div><div class="label">Method</div></div>'
            f'  <div class="stat"><div class="value">{epochs}</div><div class="label">Epochs</div></div>'
            f'  <div class="stat"><div class="value">{metrics["num_examples"]}</div><div class="label">Eval Examples</div></div>'
            f'</div>'
            f'<div class="note">'
            f'Check the <b>train</b> task report for training loss/LR charts, '
            f'and the <b>evaluate</b> task report for detailed example comparisons.'
            f'</div>'
        ),
        do_flush=True,
    )

    log.info(f"Pipeline complete. Improvement: {metrics['improvement']:+.1f}pp")
    return finetuned_dir

# {{/docs-fragment pipeline}}

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(pipeline)
    print(run.url)
    run.wait()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/llm_fine_tuning_lora_qlora/llm_fine_tuning_lora_qlora.py*

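`SFTTrainer.train` blocks, so the `train` task runs it through `asyncio.to_thread`. As a result, `MetricsCallback.on_log` fires on a worker thread. From there it uses `asyncio.run_coroutine_threadsafe` to schedule each `flyte.report.replace.aio` update on the task's event loop, which was captured earlier with `asyncio.get_running_loop()`.

The standard-library sketch below isolates that pattern. `fake_train` and `publish_report` are hypothetical stand-ins for the trainer and the Flyte report call. One step here is an addition that the script does not have: the sketch keeps the returned futures and awaits them before the final update, so no queued live update can finish after the final report.

```python
import asyncio
import time


async def publish_report(updates: list[str], text: str) -> None:
    """Hypothetical stand-in for an async report API such as flyte.report.replace.aio."""
    await asyncio.sleep(0)  # yield once, as a real I/O call would
    updates.append(text)


def fake_train(on_log, steps: int) -> None:
    """Hypothetical blocking stand-in for trainer.train(); calls on_log on this worker thread."""
    for step in range(1, steps + 1):
        time.sleep(0.01)  # simulate one training step
        on_log(step)


async def train_task(steps: int = 5) -> list[str]:
    updates: list[str] = []
    loop = asyncio.get_running_loop()  # captured on the event-loop thread
    pending = []

    def on_log(step: int) -> None:
        # Runs on the worker thread: schedule the coroutine on the loop instead of awaiting it.
        fut = asyncio.run_coroutine_threadsafe(publish_report(updates, f"step {step}"), loop)
        pending.append(fut)

    # Run the blocking loop off the event-loop thread so the scheduled reports can execute.
    await asyncio.to_thread(fake_train, on_log, steps)
    # Addition not in the script: drain queued updates before publishing the final one.
    await asyncio.gather(*(asyncio.wrap_future(f) for f in pending))
    await publish_report(updates, "done")
    return updates


print(asyncio.run(train_task()))
```
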
```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.4.0",
#    "torch>=2.1.0",
#    "transformers>=4.45.0",
#    "peft>=0.13.0",
#    "trl>=0.12.0",
#    "bitsandbytes>=0.44.0",
#    ...
# ]
# ///
```

## Orchestrate the pipeline
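
The `pipeline` entry point exposes `method`, `lora_r`, and `lora_alpha` as inputs and forwards them to `train`, where `LoraConfig` wraps each target projection. For a frozen weight of shape `d_out x d_in`, LoRA trains two factors: B (`d_out x r`) and A (`r x d_in`). Each targeted layer therefore adds `r * (d_in + d_out)` parameters, and the update `B @ A` is scaled by `lora_alpha / r`, the "effective scale" in the script's comments.

The sketch below does that bookkeeping for a hypothetical block containing the seven projections the script targets. The dimensions are illustrative only and are not SmolLM2-135M's actual configuration.

```python
from dataclasses import dataclass


@dataclass
class Linear:
    """Hypothetical linear layer: weight has shape (out_features, in_features)."""
    name: str
    in_features: int
    out_features: int


def lora_params(layer: Linear, r: int) -> int:
    """Parameters in the LoRA factors A (r x in) and B (out x r) for one layer."""
    return r * (layer.in_features + layer.out_features)


def lora_summary(layers: list[Linear], r: int, alpha: int, num_blocks: int) -> dict:
    """Adapter size vs. frozen targeted weights across `num_blocks` identical blocks."""
    frozen = num_blocks * sum(layer.in_features * layer.out_features for layer in layers)
    adapter = num_blocks * sum(lora_params(layer, r) for layer in layers)
    return {
        "frozen_linear_params": frozen,
        "adapter_params": adapter,
        "adapter_pct": adapter / frozen * 100,
        "scale": alpha / r,  # multiplier applied to B @ A before it is added to W
    }


# Illustrative dimensions only (hidden=512, MLP=2048, 12 blocks), not a real model config.
hidden, mlp = 512, 2048
block = [
    Linear("q_proj", hidden, hidden),
    Linear("k_proj", hidden, hidden),
    Linear("v_proj", hidden, hidden),
    Linear("o_proj", hidden, hidden),
    Linear("gate_proj", hidden, mlp),
    Linear("up_proj", hidden, mlp),
    Linear("down_proj", mlp, hidden),
]
summary = lora_summary(block, r=16, alpha=32, num_blocks=12)
print(summary)
```

With these shapes, `r=16` adds about 4.5% of the targeted linear weights. The adapter size grows linearly with `lora_r`, while the scale `lora_alpha / r` shrinks unless `lora_alpha` is raised as well. The pipeline defaults (`lora_r=16`, `lora_alpha=32`) give a scale of 2.0.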

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.4.0",
#    "torch>=2.1.0",
#    "transformers>=4.45.0",
#    "peft>=0.13.0",
#    "trl>=0.12.0",
#    "datasets>=3.0.0",
#    "bitsandbytes>=0.44.0",
#    "accelerate>=0.34.0",
# ]
# main = "pipeline"
# params = ""
# ///
import asyncio
import json
import logging
import os
import tempfile

import flyte
import flyte.io
import flyte.report

# {{docs-fragment env}}
import os

main_img = flyte.Image.from_uv_script(__file__, name="llm-fine-tuning-lora-qlora", pre=True)

gpu_env = flyte.TaskEnvironment(
    name="llm-fine-tuning-lora-qlora-gpu",
    image=main_img,
    resources=flyte.Resources(cpu=4, memory="24Gi", gpu=1),
    secrets=[flyte.Secret(key="huggingface-token", as_env_var="HF_TOKEN")],
)

cpu_env = flyte.TaskEnvironment(
    name="llm-fine-tuning-lora-qlora-cpu",
    image=main_img,
    resources=flyte.Resources(cpu=2, memory="8Gi"),
    depends_on=[gpu_env],
)

HF_TOKEN = os.environ.get("HF_TOKEN")
# {{/docs-fragment env}}

from report_helpers import make_bar_chart, make_line_chart, pipeline_step_indicator, wrap_report

logging.basicConfig(level=logging.WARNING, format="%(message)s", force=True)
log = logging.getLogger(__name__)
log.setLevel(logging.INFO)

# ------------------------------------------------------------------
# Task 1: Prepare dataset
# ------------------------------------------------------------------

@cpu_env.task(cache="auto")
async def prepare_data(
    dataset_name: str = "b-mc2/sql-create-context",
    max_train_samples: int = 5000,
    max_eval_samples: int = 500,
) -> flyte.io.Dir:
    """Download dataset from HuggingFace and format for instruction fine-tuning."""
    from datasets import DatasetDict, load_dataset

    log.info(f"Loading dataset: {dataset_name}")
    ds = load_dataset(dataset_name, split="train")

    def format_example(ex):
        return {
            "text": (
                "### Task: Generate a SQL query to answer the question.\n"
                f"### Schema:\n{ex['context']}\n"
                f"### Question:\n{ex['question']}\n"
                f"### SQL:\n{ex['answer']}\n<|endoftext|>"
            )
        }

    ds = ds.map(format_example)

    # Split into train and eval
    total = len(ds)
    train_end = min(max_train_samples, total - max_eval_samples)
    eval_start = train_end
    eval_end = min(eval_start + max_eval_samples, total)

    processed = DatasetDict({
        "train": ds.select(range(train_end)),
        "eval": ds.select(range(eval_start, eval_end)),
    })

    output_dir = os.path.join(tempfile.mkdtemp(), "dataset")
    processed.save_to_disk(output_dir)
    log.info(f"Dataset ready: {len(processed['train'])} train, {len(processed['eval'])} eval")

    return await flyte.io.Dir.from_local(output_dir)

# ------------------------------------------------------------------
# Task 2: Train
# ------------------------------------------------------------------

@gpu_env.task(report=True)
async def train(
    model_name: str,
    data_dir: flyte.io.Dir,
    method: str = "lora",
    epochs: int = 3,
    lr: float = 2e-4,
    batch_size: int = 4,
    lora_r: int = 16,
    lora_alpha: int = 32,
) -> flyte.io.Dir:
    """Fine-tune a model using full, LoRA, or QLoRA method."""
    import torch
    from datasets import load_from_disk
    from transformers import AutoModelForCausalLM, AutoTokenizer, TrainerCallback
    from trl import SFTConfig, SFTTrainer

    log.info(f"Training: model={model_name}, method={method}")

    # -- Load data --
    data_path = await data_dir.download()
    dataset = load_from_disk(data_path)

    # -- Load tokenizer --
    token_kwargs = {"token": HF_TOKEN} if HF_TOKEN else {}
    tokenizer = AutoTokenizer.from_pretrained(model_name, **token_kwargs)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # -- Initial report: loading model --
    await flyte.report.replace.aio(
        wrap_report(
            f"<h2>Loading Model...</h2>"
            f"<h3>{model_name}</h3>"
            f'<div class="card">'
            f"<p><b>Method:</b> <span class=\"badge badge-info\">{method.upper()}</span></p>"
            f"<p><b>Dataset:</b> {len(dataset['train']):,} train / {len(dataset['eval']):,} eval</p>"
            f"</div>"
        ),
        do_flush=True,
    )

    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    dtype = torch.bfloat16 if use_bf16 else torch.float32

    if method == "qlora":
        from transformers import BitsAndBytesConfig

        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            **token_kwargs,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=dtype,
                bnb_4bit_use_double_quant=True,
            ),
            dtype=dtype,
            device_map="auto",
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            **token_kwargs,
            dtype=dtype,
            device_map="auto",
        )

    # -- Apply LoRA adapters --
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())

    if method in ("lora", "qlora"):
        from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

        if method == "qlora":
            model = prepare_model_for_kbit_training(model)

        lora_config = LoraConfig(
            r=lora_r,              # Rank — size of the low-rank matrices. Higher = more capacity but more params
            lora_alpha=lora_alpha,  # Scaling factor — controls adapter impact. Effective scale = alpha/r
            # Attention layers — LoRA adapters inject low-rank updates here:
            #   q_proj (Query)     — what to look for in context
            #   k_proj (Key)       — what each token offers to match against
            #   v_proj (Value)     — what information to extract once matched
            #   o_proj (Output)    — combines multi-head attention results
            # MLP layers — LoRA adapters also update the feed-forward network:
            #   gate_proj (Gate)   — controls how much information flows through (SwiGLU activation)
            #   up_proj (Up)       — projects to a higher dimension for richer representations
            #   down_proj (Down)   — projects back down to the model's hidden size
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            lora_dropout=0.05,     # Dropout on adapter weights — light regularization to prevent overfitting
            bias="none",           # Don't train bias terms — keeps adapter small and stable
            task_type="CAUSAL_LM", # Tells PEFT this is a text generation model (vs classification, etc.)
        )
        model = get_peft_model(model, lora_config)
        trainable_params, total_params = model.get_nb_trainable_parameters()
        log.info(f"Trainable params: {trainable_params:,} / {total_params:,} ({trainable_params / total_params * 100:.1f}%)")

    # -- Live training report state --
    training_log: list[dict] = []
    loop = asyncio.get_running_loop()

    method_badge = f'<span class="badge badge-info">{method.upper()}</span>'
    if method == "qlora":
        method_badge = f'<span class="badge badge-success">QLoRA (4-bit)</span>'
    elif method == "full":
        method_badge = f'<span class="badge badge-danger">Full Fine-Tune</span>'

    def _build_training_report(max_steps: int) -> str:
        """Build the live training report HTML from current training_log."""
        stats_html = f"""
        <h2>Training in Progress...</h2>
        <h3>{model_name}</h3>
        <div class="stat-grid">
          <div class="stat"><div class="value">{method.upper()}</div><div class="label">Method</div></div>
          <div class="stat"><div class="value">{len(dataset['train']):,}</div><div class="label">Train Examples</div></div>
          <div class="stat"><div class="value">{epochs}</div><div class="label">Epochs</div></div>
          <div class="stat"><div class="value">{lr}</div><div class="label">Learning Rate</div></div>
          <div class="stat"><div class="value">{batch_size}</div><div class="label">Batch Size</div></div>
          <div class="stat"><div class="value">{trainable_params / total_params * 100:.1f}%</div><div class="label">Trainable</div></div>
        </div>
        <p>Method: {method_badge} | Total params: {total_params:,} | Trainable: {trainable_params:,}</p>
        """

        charts_html = ""
        if training_log:
            current = training_log[-1]
            progress_pct = current["step"] / max_steps * 100 if max_steps else 0
            charts_html += f"""
            <div class="card">
              <b>Step {current['step']}/{max_steps}</b>
              ({progress_pct:.0f}%) |
              Epoch {current['epoch']:.2f}/{epochs} |
              Loss: <span class="highlight">{current['loss']:.4f}</span>
              <div style="background:#e9ecef;border-radius:4px;height:8px;margin-top:8px;">
                <div style="background:#0f3460;width:{progress_pct:.1f}%;height:100%;border-radius:4px;"></div>
              </div>
            </div>
            """

            loss_chart = make_line_chart(
                data=training_log,
                x_key="epoch",
                y_keys=["loss"],
                title="Training Loss",
                x_label="Epoch",
                y_label="Loss",
                colors=["#5a7db5"],
            )
            charts_html += f'<div class="chart-container">{loss_chart}</div>'

            if "lr" in training_log[0]:
                lr_chart = make_line_chart(
                    data=training_log,
                    x_key="epoch",
                    y_keys=["lr"],
                    title="Learning Rate Schedule",
                    x_label="Epoch",
                    y_label="LR",
                    colors=["#0f3460"],
                )
                charts_html += f'<div class="chart-container">{lr_chart}</div>'

            if "grad_norm" in training_log[0]:
                grad_chart = make_line_chart(
                    data=training_log,
                    x_key="epoch",
                    y_keys=["grad_norm"],
                    title="Gradient Norm",
                    x_label="Epoch",
                    y_label="Grad Norm",
                    colors=["#06d6a0"],
                )
                charts_html += f'<div class="chart-container">{grad_chart}</div>'

        return wrap_report(stats_html + charts_html)

    # -- Metrics callback with live report updates --
    class MetricsCallback(TrainerCallback):
        def on_log(self, args, state, control, logs=None, **kwargs):
            if not logs or "loss" not in logs:
                return
            entry = {
                "step": state.global_step,
                "epoch": round(logs.get("epoch", 0), 2),
                "loss": round(logs["loss"], 4),
            }
            if "learning_rate" in logs:
                entry["lr"] = logs["learning_rate"]
            if "grad_norm" in logs:
                entry["grad_norm"] = round(float(logs["grad_norm"]), 4)
            training_log.append(entry)
            log.info(
                f"step={state.global_step}/{state.max_steps} "
                f"epoch={entry['epoch']:.2f} "
                f"loss={entry['loss']:.4f}"
            )

            asyncio.run_coroutine_threadsafe(
                flyte.report.replace.aio(
                    _build_training_report(state.max_steps),
                    do_flush=True,
                ),
                loop,
            )

    # -- Train --
    output_dir = os.path.join(tempfile.mkdtemp(), "checkpoints")
    training_args = SFTConfig(
        output_dir=output_dir,
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        learning_rate=lr,
        logging_steps=10,
        save_strategy="epoch",
        bf16=use_bf16,
        fp16=not use_bf16 and torch.cuda.is_available(),
        gradient_accumulation_steps=4,
        warmup_steps=10,
        report_to="none",
    )

    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["eval"],
        processing_class=tokenizer,
        callbacks=[MetricsCallback()],
    )

    log.info("Starting training...")
    await asyncio.to_thread(trainer.train)
    log.info("Training complete.")

    # -- Merge LoRA weights and save --
    save_dir = os.path.join(tempfile.mkdtemp(), "finetuned_model")

    if method in ("lora", "qlora"):
        log.info("Merging LoRA weights into base model...")
        model = model.merge_and_unload()

    model.save_pretrained(save_dir)
    tokenizer.save_pretrained(save_dir)
    log.info(f"Model saved to {save_dir}")

    # -- Final training report --
    final_loss = training_log[-1]["loss"] if training_log else "N/A"

    loss_chart = make_line_chart(
        data=training_log,
        x_key="epoch",
        y_keys=["loss"],
        title="Training Loss",
        x_label="Epoch",
        y_label="Loss",
        colors=["#5a7db5"],
    ) if training_log else ""

    lr_chart = ""
    if training_log and "lr" in training_log[0]:
        lr_chart = make_line_chart(
            data=training_log,
            x_key="epoch",
            y_keys=["lr"],
            title="Learning Rate Schedule",
            x_label="Epoch",
            y_label="LR",
            colors=["#0f3460"],
        )

    await flyte.report.replace.aio(
        wrap_report(
            f"<h2>Training Complete</h2>"
            f"<h3>{model_name}</h3>"
            f'<div class="stat-grid">'
            f'  <div class="stat"><div class="value">{method.upper()}</div><div class="label">Method</div></div>'
            f'  <div class="stat"><div class="value">{final_loss}</div><div class="label">Final Loss</div></div>'
            f'  <div class="stat"><div class="value">{epochs}</div><div class="label">Epochs</div></div>'
            f'  <div class="stat"><div class="value">{total_params:,}</div><div class="label">Total Params</div></div>'
            f'  <div class="stat"><div class="value">{trainable_params:,}</div><div class="label">Trainable Params</div></div>'
            f'  <div class="stat"><div class="value">{trainable_params / total_params * 100:.1f}%</div><div class="label">% Trainable</div></div>'
            f'</div>'
            f'<div class="chart-container">{loss_chart}</div>'
            f'{f"""<div class="chart-container">{lr_chart}</div>""" if lr_chart else ""}'
        ),
        do_flush=True,
    )

    return await flyte.io.Dir.from_local(save_dir)

# ------------------------------------------------------------------
# Task 3: Evaluate — before/after comparison
# ------------------------------------------------------------------

@gpu_env.task(report=True)
async def evaluate(
    model_name: str,
    finetuned_dir: flyte.io.Dir,
    data_dir: flyte.io.Dir,
    num_examples: int = 50,
) -> str:
    """Compare base model vs fine-tuned model on test examples."""
    import torch
    from datasets import load_from_disk
    from transformers import AutoModelForCausalLM, AutoTokenizer

    log.info("Starting evaluation...")
    await flyte.report.replace.aio(
        wrap_report(
            "<h2>Evaluation</h2>"
            '<div class="card"><p>Loading models and running inference...</p></div>'
        ),
        do_flush=True,
    )

    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    dtype = torch.bfloat16 if use_bf16 else torch.float32

    # Load eval data
    data_path = await data_dir.download()
    dataset = load_from_disk(data_path)
    eval_ds = dataset["eval"].select(range(min(num_examples, len(dataset["eval"]))))

    # Load tokenizer
    token_kwargs = {"token": HF_TOKEN} if HF_TOKEN else {}
    tokenizer = AutoTokenizer.from_pretrained(model_name, **token_kwargs)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    def generate_sql(model, prompt, max_new_tokens=128):
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id,
            )
        return tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True).strip()

    def normalize_sql(sql):
        """Extract the first SQL statement and normalize for comparison."""
        # Truncate at first ### or newline to isolate the SQL
        for stop in ["###", "\n"]:
            if stop in sql:
                sql = sql[:sql.index(stop)]
        # Strip trailing semicolons before collapsing whitespace so "... t ;" matches "... t".
        return " ".join(sql.lower().strip().rstrip(";").split())

    def build_prompt(example):
        return (
            "### Task: Generate a SQL query to answer the question.\n"
            f"### Schema:\n{example['context']}\n"
            f"### Question:\n{example['question']}\n"
            "### SQL:\n"
        )

    # -- Run base model --
    log.info(f"Loading base model: {model_name}")
    await flyte.report.replace.aio(
        wrap_report(
            f"<h2>Evaluation</h2>"
            f'<div class="stat-grid">'
            f'  <div class="stat"><div class="value">{len(eval_ds)}</div><div class="label">Eval Examples</div></div>'
            f'  <div class="stat"><div class="value">1/2</div><div class="label">Phase</div></div>'
            f'</div>'
            f'<div class="card"><p>Running <b>base model</b> inference...</p>'
            f'<div style="background:#e9ecef;border-radius:4px;height:8px;margin-top:8px;">'
            f'<div style="background:#adb5bd;width:25%;height:100%;border-radius:4px;"></div>'
            f'</div></div>'
        ),
        do_flush=True,
    )

    base_model = AutoModelForCausalLM.from_pretrained(
        model_name, **token_kwargs, dtype=dtype, device_map="auto",
    )

    base_results = []
    for i, example in enumerate(eval_ds):
        prompt = build_prompt(example)
        generated = generate_sql(base_model, prompt)
        base_results.append(generated)
        if (i + 1) % 10 == 0:
            log.info(f"Base model: {i + 1}/{len(eval_ds)}")
            pct = (i + 1) / len(eval_ds) * 50
            await flyte.report.replace.aio(
                wrap_report(
                    f"<h2>Evaluation</h2>"
                    f'<div class="card"><p>Running <b>base model</b> inference... {i + 1}/{len(eval_ds)}</p>'
                    f'<div style="background:#e9ecef;border-radius:4px;height:8px;margin-top:8px;">'
                    f'<div style="background:#adb5bd;width:{pct:.0f}%;height:100%;border-radius:4px;"></div>'
                    f'</div></div>'
                ),
                do_flush=True,
            )

    del base_model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # -- Run fine-tuned model --
    log.info("Loading fine-tuned model...")
    await flyte.report.replace.aio(
        wrap_report(
            f"<h2>Evaluation</h2>"
            f'<div class="card"><p>Running <b>fine-tuned model</b> inference...</p>'
            f'<div style="background:#e9ecef;border-radius:4px;height:8px;margin-top:8px;">'
            f'<div style="background:#0f3460;width:50%;height:100%;border-radius:4px;"></div>'
            f'</div></div>'
        ),
        do_flush=True,
    )

    ft_path = await finetuned_dir.download()
    ft_model = AutoModelForCausalLM.from_pretrained(
        ft_path, dtype=dtype, device_map="auto",
    )

    ft_results = []
    for i, example in enumerate(eval_ds):
        prompt = build_prompt(example)
        generated = generate_sql(ft_model, prompt)
        ft_results.append(generated)
        if (i + 1) % 10 == 0:
            log.info(f"Fine-tuned model: {i + 1}/{len(eval_ds)}")
            pct = 50 + (i + 1) / len(eval_ds) * 50
            await flyte.report.replace.aio(
                wrap_report(
                    f"<h2>Evaluation</h2>"
                    f'<div class="card"><p>Running <b>fine-tuned model</b> inference... {i + 1}/{len(eval_ds)}</p>'
                    f'<div style="background:#e9ecef;border-radius:4px;height:8px;margin-top:8px;">'
                    f'<div style="background:#0f3460;width:{pct:.0f}%;height:100%;border-radius:4px;"></div>'
                    f'</div></div>'
                ),
                do_flush=True,
            )

    del ft_model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # -- Score --
    base_correct = 0
    ft_correct = 0
    comparisons = []

    for i, example in enumerate(eval_ds):
        expected = example["answer"]
        base_gen = base_results[i]
        ft_gen = ft_results[i]

        base_match = normalize_sql(base_gen) == normalize_sql(expected)
        ft_match = normalize_sql(ft_gen) == normalize_sql(expected)

        if base_match:
            base_correct += 1
        if ft_match:
            ft_correct += 1

        comparisons.append({
            "question": example["question"],
            "schema": example["context"],
            "expected": expected,
            "base": base_gen,
            "finetuned": ft_gen,
            "base_correct": base_match,
            "ft_correct": ft_match,
        })

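    total = len(eval_ds)
    # An empty eval split (num_examples=0) would otherwise raise ZeroDivisionError.
    base_acc = base_correct / total * 100 if total else 0.0
    ft_acc = ft_correct / total * 100 if total else 0.0
    improvement = ft_acc - base_acc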

    log.info(f"Base model accuracy: {base_acc:.1f}% ({base_correct}/{total})")
    log.info(f"Fine-tuned accuracy: {ft_acc:.1f}% ({ft_correct}/{total})")

    # -- Build final eval report --
    improvement_badge = (
        f'<span class="badge badge-success">+{improvement:.1f}pp</span>'
        if improvement > 0
        else f'<span class="badge badge-danger">{improvement:.1f}pp</span>'
    )

    bar_chart = make_bar_chart(
        labels=["Exact Match Accuracy"],
        series={
            "Base Model": [base_acc],
            "Fine-Tuned": [ft_acc],
        },
        title="Base vs Fine-Tuned Accuracy",
        colors=["#adb5bd", "#0f3460"],
        y_max_cap=100.0,
    )

    # Escape text before embedding it in HTML: SQL such as "age > 30" contains <, >, &.
    from html import escape

    examples_html = ""
    for c in comparisons[:10]:
        base_badge = '<span class="badge badge-success">correct</span>' if c["base_correct"] else '<span class="badge badge-danger">wrong</span>'
        ft_badge = '<span class="badge badge-success">correct</span>' if c["ft_correct"] else '<span class="badge badge-danger">wrong</span>'
        examples_html += f"""
        <div class="card">
          <p><b>Q:</b> {escape(c['question'])}</p>
          <p style="font-size:0.85em; color:#6c757d;"><b>Schema:</b> {escape(c['schema'][:200])}...</p>
          <table>
            <tr><th>Source</th><th>SQL</th><th>Result</th></tr>
            <tr><td>Expected</td><td><code>{escape(c['expected'])}</code></td><td></td></tr>
            <tr><td>Base</td><td><code>{escape(c['base'][:200])}</code></td><td>{base_badge}</td></tr>
            <tr><td>Fine-tuned</td><td><code>{escape(c['finetuned'][:200])}</code></td><td>{ft_badge}</td></tr>
          </table>
        </div>"""

    await flyte.report.replace.aio(
        wrap_report(
            f"<h2>Evaluation Results</h2>"
            f'<div class="stat-grid">'
            f'  <div class="stat"><div class="value">{base_acc:.1f}%</div><div class="label">Base Accuracy</div></div>'
            f'  <div class="stat"><div class="value">{ft_acc:.1f}%</div><div class="label">Fine-Tuned Accuracy</div></div>'
            f'  <div class="stat"><div class="value">{improvement:+.1f}pp</div><div class="label">Improvement</div></div>'
            f'  <div class="stat"><div class="value">{total}</div><div class="label">Eval Examples</div></div>'
            f'</div>'
            f'<div class="chart-container">{bar_chart}</div>'
            f'<h3>Example Comparisons {improvement_badge}</h3>'
            f'{examples_html}'
            f'<div class="note">'
            f'<b>Note:</b> Exact match accuracy compares normalized SQL output. '
            f'The fine-tuned model may generate semantically correct queries that differ in formatting.'
            f'</div>'
        ),
        do_flush=True,
    )

    return json.dumps({
        "base_accuracy": round(base_acc, 1),
        "finetuned_accuracy": round(ft_acc, 1),
        "improvement": round(ft_acc - base_acc, 1),
        "num_examples": total,
        "comparisons": comparisons[:10],
    })

# ------------------------------------------------------------------
# Pipeline: orchestrate everything
# ------------------------------------------------------------------

# {{docs-fragment pipeline}}
@cpu_env.task(report=True)
async def pipeline(
    model_name: str = "HuggingFaceTB/SmolLM2-135M",
    dataset_name: str = "b-mc2/sql-create-context",
    method: str = "lora",
    epochs: int = 3,
    lr: float = 2e-4,
    batch_size: int = 4,
    max_train_samples: int = 5000,
    max_eval_samples: int = 500,
    num_eval_examples: int = 50,
    lora_r: int = 16,
    lora_alpha: int = 32,
) -> flyte.io.Dir:
    """
    End-to-end LLM fine-tuning pipeline.

    1. Download and format dataset
    2. Fine-tune model (full / LoRA / QLoRA)
    3. Evaluate: before/after comparison on test set

    Returns the fine-tuned model directory so it can be served directly.
    """
    log.info(f"Pipeline: {model_name} | method={method} | dataset={dataset_name}")
    steps = ["Prepare Data", "Train", "Evaluate"]

    method_badge = f'<span class="badge badge-info">{method.upper()}</span>'

    # Step 1: Prepare data
    await flyte.report.replace.aio(
        wrap_report(
            f"<h2>LLM Fine-Tuning Pipeline</h2>"
            f"<h3>{model_name} {method_badge}</h3>"
            f'{pipeline_step_indicator(0, steps)}'
            f'<div class="card"><p>Downloading and formatting dataset: <b>{dataset_name}</b>...</p></div>'
        ),
        do_flush=True,
    )

    data_dir = await prepare_data(dataset_name, max_train_samples, max_eval_samples)

    # Step 2: Train
    await flyte.report.replace.aio(
        wrap_report(
            f"<h2>LLM Fine-Tuning Pipeline</h2>"
            f"<h3>{model_name} {method_badge}</h3>"
            f'{pipeline_step_indicator(1, steps)}'
            f'<div class="card"><p>Training in progress... check the <b>train</b> task report for live charts.</p></div>'
        ),
        do_flush=True,
    )

    finetuned_dir = await train(
        model_name, data_dir, method, epochs, lr, batch_size, lora_r, lora_alpha,
    )

    # Step 3: Evaluate
    await flyte.report.replace.aio(
        wrap_report(
            f"<h2>LLM Fine-Tuning Pipeline</h2>"
            f"<h3>{model_name} {method_badge}</h3>"
            f'{pipeline_step_indicator(2, steps)}'
            f'<div class="card"><p>Evaluating base vs fine-tuned model...</p></div>'
        ),
        do_flush=True,
    )

    result = await evaluate(model_name, finetuned_dir, data_dir, num_eval_examples)
    metrics = json.loads(result)

    # Final pipeline report
    improvement = metrics["improvement"]
    improvement_badge = (
        f'<span class="badge badge-success">+{improvement:.1f}pp</span>'
        if improvement > 0
        else f'<span class="badge badge-danger">{improvement:.1f}pp</span>'
    )

    await flyte.report.replace.aio(
        wrap_report(
            f"<h2>Pipeline Complete</h2>"
            f"<h3>{model_name} {method_badge}</h3>"
            f'{pipeline_step_indicator(3, steps)}'
            f'<div class="stat-grid">'
            f'  <div class="stat"><div class="value">{metrics["base_accuracy"]}%</div><div class="label">Base Accuracy</div></div>'
            f'  <div class="stat"><div class="value">{metrics["finetuned_accuracy"]}%</div><div class="label">Fine-Tuned Accuracy</div></div>'
            f'  <div class="stat"><div class="value">{improvement:+.1f}pp</div><div class="label">Improvement {improvement_badge}</div></div>'
            f'  <div class="stat"><div class="value">{method.upper()}</div><div class="label">Method</div></div>'
            f'  <div class="stat"><div class="value">{epochs}</div><div class="label">Epochs</div></div>'
            f'  <div class="stat"><div class="value">{metrics["num_examples"]}</div><div class="label">Eval Examples</div></div>'
            f'</div>'
            f'<div class="note">'
            f'Check the <b>train</b> task report for training loss/LR charts, '
            f'and the <b>evaluate</b> task report for detailed example comparisons.'
            f'</div>'
        ),
        do_flush=True,
    )

    log.info(f"Pipeline complete. Improvement: {metrics['improvement']:+.1f}pp")
    return finetuned_dir

# {{/docs-fragment pipeline}}

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(pipeline)
    print(run.url)
    run.wait()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/llm_fine_tuning_lora_qlora/llm_fine_tuning_lora_qlora.py*
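The `lora_r` and `lora_alpha` parameters of `pipeline` (defaults 16 and 32) set the size and scale of the LoRA update. In the standard LoRA formulation, a frozen weight matrix W of shape d_out × d_in gets two trainable matrices, A (r × d_in) and B (d_out × r), and the layer computes W x + (alpha / r) B A x. The sketch below only does that arithmetic for one hypothetical 576 × 576 projection. The layer size is an illustrative assumption, and which layers the tutorial actually adapts is set in its training code, not here.

```python
def lora_added_params(d_in: int, d_out: int, r: int) -> int:
    """Trainable parameters LoRA adds to one frozen d_out x d_in weight."""
    # A is r x d_in and B is d_out x r; W itself stays frozen.
    return r * d_in + d_out * r


def lora_scale(alpha: float, r: int) -> float:
    """Factor applied to the low-rank update B @ A before it is added to W x."""
    return alpha / r


d = 576            # hypothetical square projection size, for illustration only
r, alpha = 16, 32  # the pipeline's lora_r and lora_alpha defaults

frozen = d * d
trainable = lora_added_params(d, d, r)
print(f"frozen weights in this layer: {frozen:,}")
print(f"LoRA trainable parameters:   {trainable:,} ({trainable / frozen:.2%})")
print(f"update scale alpha / r:      {lora_scale(alpha, r)}")
```

Doubling `lora_r` doubles the trainable parameter count per adapted layer; keeping `lora_alpha` at twice `lora_r` keeps the update scale fixed at 2.0.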

## Run the workflow

If the base model is gated on Hugging Face, create a secret containing your Hugging Face access token:

```
flyte create secret huggingface-token <YOUR_HF_TOKEN>
```

Clone the `unionai-examples` repository, then run the script from the [example directory](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/llm_fine_tuning_lora_qlora):

```
cd v2/tutorials/llm_fine_tuning_lora_qlora
uv run --script llm_fine_tuning_lora_qlora.py
```

To try QLoRA instead of the default LoRA method, pass `--method qlora`:

```
flyte run llm_fine_tuning_lora_qlora.py pipeline --method qlora --epochs 3
```

QLoRA requires a CUDA-capable GPU. LoRA and full fine-tuning use the same `pipeline` entry point (selected with `--method`) but need different amounts of GPU memory.
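To see roughly why the three methods differ in memory, here is a back-of-envelope estimate. It assumes mixed-precision AdamW (about 16 bytes per trainable parameter: bf16 weight and gradient, an fp32 master copy, and two fp32 Adam moments), frozen base weights at 2 bytes in bf16 or about 0.5 bytes in 4-bit NF4 for QLoRA, and it ignores activations, quantization constants, and framework overhead. The 7B model and 40M adapter sizes are hypothetical; by the same rule, even full fine-tuning of the 135M-parameter default needs only about 2 GB, so the gap matters mainly for larger base models.

```python
# Rough bytes-per-parameter figures (assumptions for a back-of-envelope estimate).
BASE_WEIGHT_BYTES = {"bf16": 2.0, "nf4": 0.5}  # frozen base weights
TRAINABLE_BYTES = 16.0  # bf16 weight + bf16 grad + fp32 master copy + AdamW m and v


def estimate_gb(total_params: int, trainable_params: int, base_format: str) -> float:
    """Estimate weight + gradient + optimizer memory in GB (activations excluded)."""
    frozen = total_params - trainable_params
    total_bytes = frozen * BASE_WEIGHT_BYTES[base_format] + trainable_params * TRAINABLE_BYTES
    return total_bytes / 1e9


total = 7_000_000_000  # hypothetical 7B-parameter model
adapters = 40_000_000  # hypothetical LoRA adapter size

for name, trainable, fmt in [
    ("full", total, "bf16"),
    ("lora", adapters, "bf16"),
    ("qlora", adapters, "nf4"),
]:
    print(f"{name:>5}: ~{estimate_gb(total, trainable, fmt):.1f} GB")
```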

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/model-training/bert-fine-tuning-emotion ===

# BERT emotion classification

> [!NOTE]
> Code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/bert_fine_tuning_emotion).

This tutorial fine-tunes a BERT-style model (ModernBERT by default) on the [dair-ai/emotion](https://huggingface.co/datasets/dair-ai/emotion) Twitter dataset for six-way emotion classification: sadness, joy, love, anger, fear, and surprise. The pipeline trains the classifier, evaluates with a confusion matrix and per-class F1, and explores inference with attention and token-importance visualizations in Flyte reports.

Flyte provides:

- **GPU fine-tuning** with live training loss charts.
- **Rich evaluation reports** including confusion matrices and confidence bars.
- **Cached dataset loading** for repeatable experiments.
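The evaluation reports weighted F1 via scikit-learn's `f1_score(..., average="weighted")`: each class's F1 is weighted by its support (number of true examples), which matters because class counts in dair-ai/emotion are uneven. The from-scratch sketch below is for illustration only, not the tutorial's code; it uses three classes and made-up labels for brevity. It also shows that the per-class "accuracy" in the evaluation bar chart (the fraction of class-c examples predicted as c) is the same quantity as per-class recall.

```python
from collections import Counter


def confusion_counts(y_true, y_pred, num_classes):
    """cm[i][j] counts examples whose true class is i and predicted class is j."""
    cm = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        cm[t][p] += 1
    return cm


def per_class_recall(cm, c):
    """Fraction of true class-c examples predicted as c (the bar chart's per-class accuracy)."""
    row_total = sum(cm[c])
    return cm[c][c] / row_total if row_total else 0.0


def weighted_f1(y_true, y_pred, num_classes):
    cm = confusion_counts(y_true, y_pred, num_classes)
    support = Counter(y_true)
    score = 0.0
    for c in range(num_classes):
        tp = cm[c][c]
        fp = sum(cm[r][c] for r in range(num_classes)) - tp  # predicted c, actually not c
        fn = sum(cm[c]) - tp                                 # actually c, predicted not c
        denom = 2 * tp + fp + fn
        f1 = 2 * tp / denom if denom else 0.0  # equals 2PR / (P + R); 0 when undefined
        score += f1 * support[c] / len(y_true)  # weight by the class's true-label count
    return score


y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
cm = confusion_counts(y_true, y_pred, 3)
print(cm)
print([per_class_recall(cm, c) for c in range(3)])
print(round(weighted_f1(y_true, y_pred, 3), 4))
```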

## Define the task environments

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.4.0",
#    "torch>=2.1.0",
#    "transformers>=4.46.0",
#    "datasets>=3.0.0",
#    "accelerate>=0.34.0",
#    "scikit-learn",
#    "numpy",
# ]
# main = "pipeline"
# params = ""
# ///
import json
import logging
import os
import tempfile

import flyte
import flyte.io
import flyte.report

# {{docs-fragment env}}
import os

main_img = flyte.Image.from_uv_script(__file__, name="bert-fine-tuning-emotion", pre=True)

gpu_env = flyte.TaskEnvironment(
    name="bert-fine-tuning-emotion-gpu",
    image=main_img,
    resources=flyte.Resources(cpu=4, memory="16Gi", gpu=1),
    secrets=[flyte.Secret(key="huggingface-token", as_env_var="HF_TOKEN")],
)

cpu_env = flyte.TaskEnvironment(
    name="bert-fine-tuning-emotion-cpu",
    image=main_img,
    resources=flyte.Resources(cpu=2, memory="8Gi"),
    depends_on=[gpu_env],
)

HF_TOKEN = os.environ.get("HF_TOKEN")
# {{/docs-fragment env}}

from report_helpers import (
    make_attention_text,
    make_bar_chart,
    make_confidence_bars,
    make_confusion_matrix,
    make_line_chart,
    make_token_importance_text,
    pipeline_step_indicator,
    wrap_report,
)

logging.basicConfig(level=logging.WARNING, format="%(message)s", force=True)
log = logging.getLogger(__name__)
log.setLevel(logging.INFO)

EMOTION_LABELS = ["sadness", "joy", "love", "anger", "fear", "surprise"]
EMOTION_DATASET = "dair-ai/emotion"

# ------------------------------------------------------------------
# Task 1: Get data
# ------------------------------------------------------------------

@cpu_env.task(cache="auto")
async def get_data(
    max_train_samples: int = 10000,
    max_eval_samples: int = 2000,
) -> flyte.io.Dir:
    """Download the emotion dataset and save train/eval splits.

    The dair-ai/emotion dataset contains ~20k English Twitter messages labeled
    with one of 6 emotions: sadness, joy, love, anger, fear, surprise.
    """
    from datasets import DatasetDict, load_dataset

    log.info("Loading emotion dataset...")
    ds = load_dataset(EMOTION_DATASET)

    train_ds = ds["train"].shuffle(seed=42).select(range(min(max_train_samples, len(ds["train"]))))
    eval_ds = ds["test"].shuffle(seed=42).select(range(min(max_eval_samples, len(ds["test"]))))

    processed = DatasetDict({"train": train_ds, "eval": eval_ds})

    output_dir = os.path.join(tempfile.mkdtemp(), "dataset")
    processed.save_to_disk(output_dir)
    log.info(f"Dataset ready: {len(train_ds)} train, {len(eval_ds)} eval")

    return await flyte.io.Dir.from_local(output_dir)

# ------------------------------------------------------------------
# Task 2: Train
# ------------------------------------------------------------------

@gpu_env.task(report=True)
async def train(
    model_name: str,
    data_dir: flyte.io.Dir,
    epochs: int = 3,
    lr: float = 2e-5,
    batch_size: int = 16,
    warmup_steps: int = 100,
) -> flyte.io.Dir:
    """Fine-tune a BERT-style model for 6-class emotion classification."""
    import numpy as np
    import torch
    from datasets import load_from_disk
    from sklearn.metrics import accuracy_score, f1_score
    from transformers import (
        AutoModelForSequenceClassification,
        AutoTokenizer,
        Trainer,
        TrainerCallback,
        TrainingArguments,
    )

    log.info(f"Training: model={model_name}")

    id2label = {i: l for i, l in enumerate(EMOTION_LABELS)}
    label2id = {l: i for i, l in enumerate(EMOTION_LABELS)}

    await flyte.report.replace.aio(
        wrap_report(
            f"<h2>Loading Model...</h2>"
            f"<h3>{model_name}</h3>"
            f'<div class="card"><p>Preparing for emotion classification training...</p></div>'
        ),
        do_flush=True,
    )

    # -- Load data --
    data_path = await data_dir.download()
    dataset = load_from_disk(data_path)

    # -- Tokenize --
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN)

    def tokenize(examples):
        return tokenizer(examples["text"], truncation=True, max_length=128, padding="max_length")

    dataset = dataset.map(tokenize, batched=True, remove_columns=["text"])

    # -- Load model --
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        token=HF_TOKEN,
        num_labels=6,
        id2label=id2label,
        label2id=label2id,
    )

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    log.info(f"Parameters: {trainable_params:,} / {total_params:,}")

    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(0)
        gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
        log.info(f"GPU: {gpu_name} ({gpu_mem:.1f} GB)")

    # -- Metrics tracking for live report --
    training_log: list[dict] = []
    eval_log: list[dict] = []

    def _build_training_report(max_steps: int) -> str:
        stats_html = f"""
        <h2>Training in Progress...</h2>
        <h3>{model_name}</h3>
        <div class="stat-grid">
          <div class="stat"><div class="value">{len(dataset['train']):,}</div><div class="label">Train Samples</div></div>
          <div class="stat"><div class="value">{len(dataset['eval']):,}</div><div class="label">Eval Samples</div></div>
          <div class="stat"><div class="value">{epochs}</div><div class="label">Epochs</div></div>
          <div class="stat"><div class="value">{lr}</div><div class="label">Learning Rate</div></div>
          <div class="stat"><div class="value">{batch_size}</div><div class="label">Batch Size</div></div>
          <div class="stat"><div class="value">{trainable_params:,}</div><div class="label">Parameters</div></div>
        </div>
        """

        charts_html = ""

        if training_log:
            current = training_log[-1]
            progress_pct = current["step"] / max_steps * 100 if max_steps else 0
            loss_display = f"Loss: <span class=\"highlight\">{current['loss']:.4f}</span>" if current.get("loss") else ""
            charts_html += f"""
            <div class="card">
              <b>Step {current['step']}/{max_steps}</b>
              ({progress_pct:.0f}%) |
              Epoch {current['epoch']:.2f}/{epochs}
              {f' | {loss_display}' if loss_display else ''}
              <div style="background:#e9ecef;border-radius:4px;height:8px;margin-top:8px;">
                <div style="background:#0f3460;width:{progress_pct:.1f}%;height:100%;border-radius:4px;"></div>
              </div>
            </div>
            """

            loss_entries = [e for e in training_log if "loss" in e]
            if len(loss_entries) >= 2:
                loss_chart = make_line_chart(
                    data=loss_entries,
                    x_key="epoch",
                    y_keys=["loss"],
                    title="Training Loss",
                    x_label="Epoch",
                    y_label="Loss",
                    colors=["#5a7db5"],
                )
                charts_html += f'<div class="chart-container">{loss_chart}</div>'

        if eval_log:
            latest_eval = eval_log[-1]
            best_acc = max(e.get("accuracy", 0) for e in eval_log)
            best_f1 = max(e.get("f1", 0) for e in eval_log)
            charts_html += f"""
            <div class="stat-grid" style="margin-top:16px;">
              <div class="stat"><div class="value">{latest_eval.get('accuracy', 0):.1%}</div><div class="label">Eval Accuracy</div></div>
              <div class="stat"><div class="value">{latest_eval.get('f1', 0):.1%}</div><div class="label">Eval F1</div></div>
              <div class="stat"><div class="value">{best_acc:.1%}</div><div class="label">Best Accuracy</div></div>
              <div class="stat"><div class="value">{latest_eval.get('eval_loss', 0):.4f}</div><div class="label">Eval Loss</div></div>
            </div>
            """

            if len(eval_log) >= 2:
                eval_chart = make_line_chart(
                    data=eval_log,
                    x_key="epoch",
                    y_keys=["accuracy", "f1"],
                    title="Eval Metrics Over Training",
                    x_label="Epoch",
                    y_label="Score",
                    colors=["#0f3460", "#06d6a0"],
                    y_max_cap=1.05,
                    y_display_names={"accuracy": "Accuracy", "f1": "Weighted F1"},
                )
                charts_html += f'<div class="chart-container">{eval_chart}</div>'

                eval_loss_entries = [e for e in eval_log if "eval_loss" in e]
                if eval_loss_entries:
                    eval_loss_chart = make_line_chart(
                        data=eval_loss_entries,
                        x_key="epoch",
                        y_keys=["eval_loss"],
                        title="Eval Loss",
                        x_label="Epoch",
                        y_label="Loss",
                        colors=["#e63946"],
                    )
                    charts_html += f'<div class="chart-container">{eval_loss_chart}</div>'

        return wrap_report(stats_html + charts_html)

    # -- Callbacks --
    class ReportCallback(TrainerCallback):
        def on_log(self, args, state, control, logs=None, **kwargs):
            if not logs:
                return
            entry = {
                "step": state.global_step,
                "epoch": round(logs.get("epoch", 0), 2),
            }
            if "loss" in logs:
                entry["loss"] = round(logs["loss"], 4)
            if "eval_accuracy" in logs:
                eval_log.append({
                    "epoch": entry["epoch"],
                    "accuracy": logs["eval_accuracy"],
                    "f1": logs.get("eval_f1", 0),
                    "eval_loss": logs.get("eval_loss", 0),
                })
            if "loss" in entry:
                training_log.append(entry)

            flyte.report.replace(
                _build_training_report(state.max_steps),
                do_flush=True,
            )

    # -- Compute metrics --
    def compute_metrics(eval_pred):
        logits, labels = eval_pred
        preds = np.argmax(logits, axis=-1)
        return {
            "accuracy": accuracy_score(labels, preds),
            "f1": f1_score(labels, preds, average="weighted"),
        }

    # -- Training --
    output_dir = os.path.join(tempfile.mkdtemp(), "checkpoints")
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size * 2,
        learning_rate=lr,
        logging_steps=10,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        bf16=use_bf16,
        fp16=not use_bf16 and torch.cuda.is_available(),
        warmup_steps=warmup_steps,
        report_to="none",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["eval"],
        processing_class=tokenizer,
        compute_metrics=compute_metrics,
        callbacks=[ReportCallback()],
    )

    log.info("Starting training...")
    await flyte.report.replace.aio(
        _build_training_report(0),
        do_flush=True,
    )

    trainer.train()
    log.info("Training complete.")

    # -- Save model --
    save_dir = os.path.join(tempfile.mkdtemp(), "finetuned_model")
    trainer.save_model(save_dir)
    tokenizer.save_pretrained(save_dir)
    log.info(f"Model saved to {save_dir}")

    # -- Final eval + report --
    metrics = trainer.evaluate()
    final_acc = metrics.get("eval_accuracy", 0)
    final_f1 = metrics.get("eval_f1", 0)

    final_charts = ""
    loss_entries = [e for e in training_log if "loss" in e]
    if len(loss_entries) >= 2:
        loss_chart = make_line_chart(
            data=loss_entries,
            x_key="epoch",
            y_keys=["loss"],
            title="Training Loss",
            x_label="Epoch",
            y_label="Loss",
            colors=["#5a7db5"],
        )
        final_charts += f'<div class="chart-container">{loss_chart}</div>'

    if len(eval_log) >= 2:
        eval_chart = make_line_chart(
            data=eval_log,
            x_key="epoch",
            y_keys=["accuracy", "f1"],
            title="Eval Metrics Over Training",
            x_label="Epoch",
            y_label="Score",
            colors=["#0f3460", "#06d6a0"],
            y_max_cap=1.05,
            y_display_names={"accuracy": "Accuracy", "f1": "Weighted F1"},
        )
        final_charts += f'<div class="chart-container">{eval_chart}</div>'

    await flyte.report.replace.aio(
        wrap_report(
            f"<h2>Training Complete</h2>"
            f"<h3>{model_name}</h3>"
            f'<div class="stat-grid">'
            f'  <div class="stat"><div class="value">{final_acc:.1%}</div><div class="label">Accuracy</div></div>'
            f'  <div class="stat"><div class="value">{final_f1:.1%}</div><div class="label">Weighted F1</div></div>'
            f'  <div class="stat"><div class="value">{epochs}</div><div class="label">Epochs</div></div>'
            f'  <div class="stat"><div class="value">{trainable_params:,}</div><div class="label">Parameters</div></div>'
            f'</div>'
            f"{final_charts}"
        ),
        do_flush=True,
    )

    return await flyte.io.Dir.from_local(save_dir)

# ------------------------------------------------------------------
# Task 3: Evaluate
# ------------------------------------------------------------------

@gpu_env.task(report=True)
async def evaluate(
    model_name: str,
    finetuned_dir: flyte.io.Dir,
    data_dir: flyte.io.Dir,
    num_examples: int = 200,
) -> str:
    """Compare base model (random head) vs fine-tuned on emotion classification.

    Produces confusion matrix, per-class precision/recall/F1, and overall metrics.
    """
    import numpy as np
    import torch
    from datasets import load_from_disk
    from sklearn.metrics import (
        accuracy_score,
        classification_report,
        confusion_matrix as sk_confusion_matrix,
        f1_score,
    )
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    log.info("Starting evaluation...")
    await flyte.report.replace.aio(
        wrap_report("<h2>Evaluation</h2><p>Loading models...</p>"),
        do_flush=True,
    )

    # -- Load eval data --
    data_path = await data_dir.download()
    dataset = load_from_disk(data_path)
    eval_ds = dataset["eval"].select(range(min(num_examples, len(dataset["eval"]))))
    texts = eval_ds["text"]
    labels = eval_ds["label"]

    def predict_batch(model, tokenizer, texts, batch_size=32):
        preds = []
        probs_all = []
        for i in range(0, len(texts), batch_size):
            batch = texts[i : i + batch_size]
            inputs = tokenizer(batch, truncation=True, max_length=128, padding=True, return_tensors="pt")
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = model(**inputs)
            batch_probs = torch.softmax(outputs.logits, dim=-1).cpu()
            batch_preds = torch.argmax(batch_probs, dim=-1).tolist()
            preds.extend(batch_preds)
            probs_all.extend(batch_probs.tolist())
        return preds, probs_all

    # -- Base model --
    log.info(f"Loading base model: {model_name}")
    await flyte.report.replace.aio(
        wrap_report("<h2>Evaluation</h2><p>Running base model (random classifier head)...</p>"),
        do_flush=True,
    )

    base_tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN)
    base_model = AutoModelForSequenceClassification.from_pretrained(
        model_name, token=HF_TOKEN, num_labels=6,
    )
    base_model.eval()
    if torch.cuda.is_available():
        base_model = base_model.cuda()

    base_preds, base_probs = predict_batch(base_model, base_tokenizer, texts)
    del base_model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # -- Fine-tuned model --
    log.info("Loading fine-tuned model...")
    await flyte.report.replace.aio(
        wrap_report("<h2>Evaluation</h2><p>Running fine-tuned model...</p>"),
        do_flush=True,
    )

    ft_path = await finetuned_dir.download()
    ft_tokenizer = AutoTokenizer.from_pretrained(ft_path)
    ft_model = AutoModelForSequenceClassification.from_pretrained(ft_path)
    ft_model.eval()
    if torch.cuda.is_available():
        ft_model = ft_model.cuda()

    ft_preds, ft_probs = predict_batch(ft_model, ft_tokenizer, texts)
    del ft_model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # -- Compute metrics --
    base_acc = accuracy_score(labels, base_preds) * 100
    base_f1 = f1_score(labels, base_preds, average="weighted") * 100
    ft_acc = accuracy_score(labels, ft_preds) * 100
    ft_f1 = f1_score(labels, ft_preds, average="weighted") * 100

    log.info(f"Base:      Accuracy={base_acc:.1f}%, F1={base_f1:.1f}%")
    log.info(f"Fine-tuned: Accuracy={ft_acc:.1f}%, F1={ft_f1:.1f}%")

    # -- Confusion matrix --
    ft_cm = sk_confusion_matrix(labels, ft_preds, labels=list(range(6)))
    cm_list = ft_cm.tolist()
    cm_svg = make_confusion_matrix(cm_list, EMOTION_LABELS, title="Fine-tuned Model — Confusion Matrix")

    # -- Per-class metrics --
    report_dict = classification_report(
        labels, ft_preds, labels=list(range(6)), target_names=EMOTION_LABELS,
        output_dict=True, zero_division=0,
    )
    per_class_html = "<table><tr><th>Emotion</th><th>Precision</th><th>Recall</th><th>F1</th><th>Support</th></tr>"
    for label_name in EMOTION_LABELS:
        if label_name in report_dict:
            m = report_dict[label_name]
            per_class_html += (
                f"<tr><td><b>{label_name}</b></td>"
                f"<td>{m['precision']:.1%}</td>"
                f"<td>{m['recall']:.1%}</td>"
                f"<td>{m['f1-score']:.1%}</td>"
                f"<td>{int(m['support'])}</td></tr>"
            )
    per_class_html += "</table>"

    # -- Bar chart: base vs fine-tuned --
    per_class_base_acc = []
    per_class_ft_acc = []
    for cls_idx in range(6):
        cls_mask = [i for i, l in enumerate(labels) if l == cls_idx]
        if cls_mask:
            base_cls_acc = sum(1 for i in cls_mask if base_preds[i] == cls_idx) / len(cls_mask) * 100
            ft_cls_acc = sum(1 for i in cls_mask if ft_preds[i] == cls_idx) / len(cls_mask) * 100
        else:
            base_cls_acc = 0
            ft_cls_acc = 0
        per_class_base_acc.append(base_cls_acc)
        per_class_ft_acc.append(ft_cls_acc)

    bar_chart = make_bar_chart(
        labels=EMOTION_LABELS,
        series={"Base": per_class_base_acc, "Fine-tuned": per_class_ft_acc},
        title="Per-Class Accuracy — Base vs Fine-tuned",
        colors=["#adb5bd", "#0f3460"],
        y_max_cap=105.0,
    )

    # -- Example predictions --
    improvement = ft_acc - base_acc
    imp_badge = "badge-success" if improvement > 0 else "badge-danger" if improvement < 0 else "badge-info"

    examples_html = ""
    for i in range(min(10, len(texts))):
        true_label = EMOTION_LABELS[labels[i]]
        ft_label = EMOTION_LABELS[ft_preds[i]]
        base_label = EMOTION_LABELS[base_preds[i]]
        ft_correct = ft_preds[i] == labels[i]
        base_correct = base_preds[i] == labels[i]
        text_preview = texts[i][:200]

        ft_badge = "badge-success" if ft_correct else "badge-danger"
        base_badge = "badge-success" if base_correct else "badge-danger"

        examples_html += f"""
<div class="card">
  <p style="font-size:0.95em;">"{text_preview}"</p>
  <p>True: <b>{true_label}</b> |
  Base: <span class="badge {base_badge}">{base_label}</span> |
  Fine-tuned: <span class="badge {ft_badge}">{ft_label}</span></p>
</div>"""

    await flyte.report.replace.aio(
        wrap_report(
            f"<h2>Evaluation Results — Emotion Classification</h2>"
            f'<div class="stat-grid">'
            f'  <div class="stat"><div class="value">{base_acc:.1f}%</div><div class="label">Base Accuracy</div></div>'
            f'  <div class="stat"><div class="value">{ft_acc:.1f}%</div><div class="label">Fine-tuned Accuracy</div></div>'
            f'  <div class="stat"><div class="value"><span class="badge {imp_badge}">{improvement:+.1f}pp</span></div><div class="label">Improvement</div></div>'
            f'  <div class="stat"><div class="value">{ft_f1:.1f}%</div><div class="label">Fine-tuned F1</div></div>'
            f'</div>'
            f'<div class="chart-container">{bar_chart}</div>'
            f'<div class="chart-container">{cm_svg}</div>'
            f"<h3>Per-Class Metrics (Fine-tuned)</h3>"
            f"{per_class_html}"
            f"<h3>Example Predictions</h3>"
            f"{examples_html}"
        ),
        do_flush=True,
    )

    return json.dumps({
        "base_accuracy": round(base_acc, 1),
        "base_f1": round(base_f1, 1),
        "finetuned_accuracy": round(ft_acc, 1),
        "finetuned_f1": round(ft_f1, 1),
        "improvement": round(improvement, 1),
        "num_examples": len(texts),
        "confusion_matrix": cm_list,
        "per_class": {k: report_dict[k] for k in EMOTION_LABELS if k in report_dict},
    })

# ------------------------------------------------------------------
# Task 4: Explore inference
# ------------------------------------------------------------------

@gpu_env.task(report=True)
async def explore_inference(
    finetuned_dir: flyte.io.Dir,
    data_dir: flyte.io.Dir,
    num_examples: int = 8,
) -> str:
    """Deep-dive into model behavior with attention and token importance.

    For a set of examples, this task produces:
    1. Predictions with full confidence distribution across all 6 emotions
    2. Attention heatmaps — which tokens the model focuses on for classification
       (CLS token attention from the last layer, averaged across heads)
    3. Token importance via gradient-based attribution — which tokens most
       influence the predicted class (gradient x embedding norm)
    4. Misclassification analysis — confident wrong predictions with explanations
    """
    import numpy as np
    import torch
    from datasets import load_from_disk
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    log.info("Starting explore_inference...")
    await flyte.report.replace.aio(
        wrap_report(
            "<h2>Explore Inference</h2>"
            "<p>Loading model for attention and attribution analysis...</p>"
        ),
        do_flush=True,
    )

    # -- Load model (with eager attention for weight extraction) --
    ft_path = await finetuned_dir.download()
    tokenizer = AutoTokenizer.from_pretrained(ft_path)

    # Use eager attention so the model returns attention weights (the SDPA and FlashAttention implementations don't)
    model = AutoModelForSequenceClassification.from_pretrained(
        ft_path,
        output_attentions=True,
        attn_implementation="eager",
    )
    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # -- Load eval data --
    data_path = await data_dir.download()
    dataset = load_from_disk(data_path)
    eval_ds = dataset["eval"]

    # Pick a diverse set of examples: try to get some from each class
    all_labels = eval_ds["label"]  # read the label column once instead of row by row
    examples_per_class = max(1, num_examples // 6)
    selected_indices = []
    for cls_idx in range(6):
        cls_indices = [i for i, lbl in enumerate(all_labels) if lbl == cls_idx]
        selected_indices.extend(cls_indices[:examples_per_class])
    # Fill any remaining slots with the first unselected examples
    # (the eval split was shuffled in get_data, so these are effectively random)
    remaining = num_examples - len(selected_indices)
    if remaining > 0:
        already_selected = set(selected_indices)
        other_indices = [i for i in range(len(eval_ds)) if i not in already_selected]
        selected_indices.extend(other_indices[:remaining])
    selected_indices = selected_indices[:num_examples]

    # -- Analyze each example --
    analyses = []
    for idx_num, ds_idx in enumerate(selected_indices):
        text = eval_ds[ds_idx]["text"]
        true_label = eval_ds[ds_idx]["label"]

        await flyte.report.replace.aio(
            wrap_report(
                f"<h2>Explore Inference</h2>"
                f"<p>Analyzing example {idx_num + 1}/{len(selected_indices)}...</p>"
                f'<div style="background:#e9ecef;border-radius:4px;height:8px;margin-top:8px;">'
                f'<div style="background:#0f3460;width:{(idx_num + 1) / len(selected_indices) * 100:.1f}%;height:100%;border-radius:4px;"></div>'
                f'</div>'
            ),
            do_flush=True,
        )

        # Tokenize
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        token_ids = inputs["input_ids"][0]
        tokens = tokenizer.convert_ids_to_tokens(token_ids)

        # Forward pass with attention
        with torch.no_grad():
            outputs = model(**inputs)

        logits = outputs.logits[0]
        probs = torch.softmax(logits, dim=-1).cpu().tolist()
        pred_idx = int(torch.argmax(logits).item())

        # -- Attention: CLS token attention from last layer --
        # outputs.attentions is a tuple with one tensor per layer,
        # each shaped (batch, num_heads, seq_len, seq_len)
        last_layer_attention = outputs.attentions[-1][0]  # (num_heads, seq_len, seq_len)
        # Average across heads, take CLS row (index 0)
        cls_attention = last_layer_attention.mean(dim=0)[0].cpu().numpy()  # (seq_len,)

        # Remove [CLS] and [SEP] and padding from visualization
        real_token_mask = []
        clean_tokens = []
        clean_attention = []
        for i, tok in enumerate(tokens):
            if tok in ("[CLS]", "[SEP]", "<s>", "</s>", "[PAD]", "<pad>"):
                continue
            if tok == tokenizer.pad_token:
                continue
            clean_tokens.append(tok)
            clean_attention.append(float(cls_attention[i]))
            real_token_mask.append(i)

        # -- Token importance via gradient attribution --
        # Re-run with gradients enabled on embeddings
        embedding_layer = None
        for name, module in model.named_modules():
            if isinstance(module, torch.nn.Embedding) and "word" in name.lower():
                embedding_layer = module
                break
        if embedding_layer is None:
            # Fallback: find the first large embedding
            for name, module in model.named_modules():
                if isinstance(module, torch.nn.Embedding) and module.weight.shape[0] > 1000:
                    embedding_layer = module
                    break

        importance_scores = [0.0] * len(clean_tokens)
        if embedding_layer is not None:
            # Capture the embedding layer's output during a gradient-enabled forward
            # pass and keep its gradient so it can be read after backward().
            captured = {}

            def hook_fn(module, module_inputs, output):
                output.retain_grad()
                captured["embeddings"] = output

            handle = embedding_layer.register_forward_hook(hook_fn)
            try:
                # Same tokenized inputs as above, but outside torch.no_grad()
                outputs_grad = model(**inputs)
            finally:
                handle.remove()

            # Gradient of the predicted class logit w.r.t. the token embeddings
            pred_score = outputs_grad.logits[0, pred_idx]
            pred_score.backward()

            embeddings = captured.get("embeddings")
            if embeddings is not None and embeddings.grad is not None:
                # Token importance = L2 norm of (gradient * embedding) per token
                token_importance = (embeddings.grad[0] * embeddings[0]).norm(dim=-1).detach().cpu().numpy()
                for clean_idx, orig_idx in enumerate(real_token_mask):
                    if orig_idx < len(token_importance):
                        importance_scores[clean_idx] = float(token_importance[orig_idx])

            model.zero_grad()

        analyses.append({
            "text": text,
            "true_label": true_label,
            "pred_idx": pred_idx,
            "probs": probs,
            "tokens": clean_tokens,
            "attention": clean_attention,
            "importance": importance_scores,
            "correct": pred_idx == true_label,
        })

    # -- Build report --
    log.info("Building explore_inference report...")

    # Overall summary
    correct = sum(1 for a in analyses if a["correct"])
    total = len(analyses)

    # Separate correct vs wrong
    correct_analyses = [a for a in analyses if a["correct"]]
    wrong_analyses = [a for a in analyses if not a["correct"]]

    # -- Build example cards --
    examples_html = ""
    for a in analyses:
        true_name = EMOTION_LABELS[a["true_label"]]
        pred_name = EMOTION_LABELS[a["pred_idx"]]
        status_badge = "badge-success" if a["correct"] else "badge-danger"
        status_text = "Correct" if a["correct"] else "Wrong"

        # Confidence bars
        conf_bars = make_confidence_bars(
            labels=EMOTION_LABELS,
            probabilities=a["probs"],
            predicted_idx=a["pred_idx"],
            true_idx=a["true_label"],
        )

        # Attention heatmap
        attention_viz = make_attention_text(
            tokens=a["tokens"],
            weights=a["attention"],
            title="Attention (what the model looks at for its prediction — darker = more attention)",
        )

        # Token importance
        importance_viz = make_token_importance_text(
            tokens=a["tokens"],
            importance=a["importance"],
            title="Token importance (L2 norm of gradient x embedding; higher = more influence on the prediction)",
        )

        text_preview = a["text"][:300]
        examples_html += f"""
<div class="card">
  <p style="font-size:1em;"><b>"{text_preview}"</b></p>
  <p>True: <b>{true_name}</b> | Predicted: <span class="badge {status_badge}">{pred_name} ({status_text})</span>
     | Confidence: <b>{a['probs'][a['pred_idx']]:.1%}</b></p>
  <div style="margin:12px 0;">{conf_bars}</div>
  <div style="margin:12px 0;">{attention_viz}</div>
  <div style="margin:12px 0;">{importance_viz}</div>
</div>"""

    # -- Misclassification spotlight --
    misclass_html = ""
    if wrong_analyses:
        # Sort by confidence (most confident wrong first)
        wrong_sorted = sorted(wrong_analyses, key=lambda a: a["probs"][a["pred_idx"]], reverse=True)

        misclass_html = "<h3>Misclassification Spotlight</h3>"
        misclass_html += '<div class="note">These are the model\'s most confident wrong predictions — cases where the model is sure but incorrect. These reveal the model\'s blind spots.</div>'

        for a in wrong_sorted[:3]:
            true_name = EMOTION_LABELS[a["true_label"]]
            pred_name = EMOTION_LABELS[a["pred_idx"]]
            conf = a["probs"][a["pred_idx"]]
            true_conf = a["probs"][a["true_label"]]

            misclass_html += f"""
<div class="card" style="border-left:4px solid #e63946;">
  <p><b>"{a['text'][:200]}"</b></p>
  <p>Predicted <span class="badge badge-danger">{pred_name}</span> ({conf:.1%})
     but true label is <span class="badge badge-info">{true_name}</span> ({true_conf:.1%})</p>
  <p style="font-size:0.85em;color:#6c757d;">
     The model assigned {conf:.1%} confidence to {pred_name} vs {true_conf:.1%} to {true_name}.
     {"The model was very sure here, so this is a genuine blind spot." if conf > 0.7 else "The model was less sure here, so this is more of a borderline case than a blind spot."}
  </p>
</div>"""

    await flyte.report.replace.aio(
        wrap_report(
            f"<h2>Explore Inference — Attention &amp; Attribution</h2>"
            f'<div class="stat-grid">'
            f'  <div class="stat"><div class="value">{correct}/{total}</div><div class="label">Correct</div></div>'
            f'  <div class="stat"><div class="value">{correct/total:.0%}</div><div class="label">Accuracy (sample)</div></div>'
            f'  <div class="stat"><div class="value">{len(wrong_analyses)}</div><div class="label">Errors to Analyze</div></div>'
            f'</div>'
            f'<div class="note">'
            f'<b>How to read the visualizations below:</b><br/>'
            f'<b>Attention heatmap:</b> Shows which tokens the [CLS] token attends to in the final layer '
            f'(averaged across all attention heads). Darker = more attention. This reveals what the model "looks at" when making its classification decision.<br/>'
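            f'<b>Token importance:</b> Gradient-based attribution showing which tokens most influence the prediction, '
            f'computed as the L2 norm of gradient &times; embedding for each token. Because it is a norm, the score '
            f'measures how strongly a token influences the prediction, not whether it supports or opposes it.'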
            f'</div>'
            f"<h3>Example Analysis</h3>"
            f"{examples_html}"
            f"{misclass_html}"
        ),
        do_flush=True,
    )

    return json.dumps({
        "num_examples": total,
        "correct": correct,
        "accuracy": round(correct / total * 100, 1),
        "num_misclassifications": len(wrong_analyses),
        "analyses": [
            {
                "text": a["text"][:200],
                "true_label": EMOTION_LABELS[a["true_label"]],
                "predicted": EMOTION_LABELS[a["pred_idx"]],
                "confidence": round(a["probs"][a["pred_idx"]], 3),
                "correct": a["correct"],
            }
            for a in analyses
        ],
    })

# ------------------------------------------------------------------
# Pipeline
# ------------------------------------------------------------------

# {{docs-fragment pipeline}}
@cpu_env.task(report=True)
async def pipeline(
    model_name: str = "answerdotai/ModernBERT-base",
    epochs: int = 3,
    lr: float = 2e-5,
    batch_size: int = 16,
    warmup_steps: int = 100,
    max_train_samples: int = 10000,
    max_eval_samples: int = 2000,
    num_eval_examples: int = 200,
    num_explore_examples: int = 12,
) -> flyte.io.Dir:
    """
    ModernBERT emotion classification pipeline.

    Returns the fine-tuned model directory (used by serve.py for deployment).

    1. Download emotion dataset (6 classes from Twitter text)
    2. Fine-tune ModernBERT for sequence classification
    3. Evaluate: base vs fine-tuned with confusion matrix
    4. Explore inference: attention heatmaps + token importance

    Args:
        model_name: HuggingFace encoder model to fine-tune.
        num_explore_examples: Number of examples for attention/attribution analysis.
    """
    log.info(f"Pipeline: {model_name} | emotion classification")
    steps = ["Get Data", "Train", "Evaluate", "Explore Inference"]

    await flyte.report.replace.aio(
        wrap_report(
            f"<h2>Emotion Classification Pipeline</h2>"
            f"<h3>{model_name}</h3>"
            f"{pipeline_step_indicator(0, steps)}"
            f'<div class="card"><p>Downloading emotion dataset...</p></div>'
        ),
        do_flush=True,
    )

    # Step 1: Get data
    data_dir = await get_data(max_train_samples, max_eval_samples)

    # Step 2: Train
    await flyte.report.replace.aio(
        wrap_report(
            f"<h2>Emotion Classification Pipeline</h2>"
            f"<h3>{model_name}</h3>"
            f"{pipeline_step_indicator(1, steps)}"
            f'<div class="card"><p>Fine-tuning for emotion classification...</p></div>'
        ),
        do_flush=True,
    )

    finetuned_dir = await train(model_name, data_dir, epochs, lr, batch_size, warmup_steps)

    # Step 3: Evaluate
    await flyte.report.replace.aio(
        wrap_report(
            f"<h2>Emotion Classification Pipeline</h2>"
            f"<h3>{model_name}</h3>"
            f"{pipeline_step_indicator(2, steps)}"
            f'<div class="card"><p>Evaluating base vs fine-tuned model...</p></div>'
        ),
        do_flush=True,
    )

    eval_result = await evaluate(model_name, finetuned_dir, data_dir, num_eval_examples)
    eval_metrics = json.loads(eval_result)

    # Step 4: Explore inference
    await flyte.report.replace.aio(
        wrap_report(
            f"<h2>Emotion Classification Pipeline</h2>"
            f"<h3>{model_name}</h3>"
            f"{pipeline_step_indicator(3, steps)}"
            f'<div class="card"><p>Analyzing attention patterns and token importance...</p></div>'
        ),
        do_flush=True,
    )

    explore_result = await explore_inference(finetuned_dir, data_dir, num_explore_examples)

    # -- Final report --
    improvement = eval_metrics["improvement"]
    imp_badge = "badge-success" if improvement > 0 else "badge-danger" if improvement < 0 else "badge-info"

    await flyte.report.replace.aio(
        wrap_report(
            f"<h2>Emotion Classification Pipeline Complete</h2>"
            f"<h3>{model_name}</h3>"
            f"{pipeline_step_indicator(4, steps)}"
            f'<div class="stat-grid">'
            f'  <div class="stat"><div class="value">{eval_metrics["base_accuracy"]}%</div><div class="label">Base Accuracy</div></div>'
            f'  <div class="stat"><div class="value">{eval_metrics["finetuned_accuracy"]}%</div><div class="label">Fine-tuned Accuracy</div></div>'
            f'  <div class="stat"><div class="value"><span class="badge {imp_badge}">{improvement:+.1f}pp</span></div><div class="label">Improvement</div></div>'
            f'  <div class="stat"><div class="value">{eval_metrics["finetuned_f1"]}%</div><div class="label">Weighted F1</div></div>'
            f'</div>'
        ),
        do_flush=True,
    )

    log.info(f"Pipeline complete. Accuracy improvement: {improvement:+.1f}pp")
    return finetuned_dir

# {{/docs-fragment pipeline}}

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(pipeline)
    print(run.url)
    run.wait()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/bert_fine_tuning_emotion/bert_fine_tuning_emotion.py*
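The `explore_inference` task scores each token by the L2 norm of gradient × embedding, which can never be negative. The difference between that magnitude and a signed gradient × input score is easiest to see on a toy linear classifier, where the gradient is known in closed form. The following is a minimal pure-Python sketch under that simplifying assumption; the helper `grad_x_input_scores` and the toy numbers are hypothetical and not part of the tutorial code.

```python
import math


def grad_x_input_scores(embeddings, weight):
    """Gradient x input attributions for a toy linear logit.

    Hypothetical toy model: logit = sum over tokens t of dot(weight, e_t).
    Its gradient w.r.t. each token embedding e_t is simply `weight`, so the
    elementwise gradient x input product for token t is weight * e_t.

    Returns two lists with one score per token:
      signed    -- the product summed over embedding dims (can be negative)
      magnitude -- the L2 norm of the product (always >= 0), as in explore_inference
    """
    signed, magnitude = [], []
    for e in embeddings:
        product = [w * x for w, x in zip(weight, e)]  # gradient * embedding
        signed.append(sum(product))
        magnitude.append(math.sqrt(sum(p * p for p in product)))
    return signed, magnitude


# Three 2-d "token embeddings" and a toy weight vector
embeddings = [[2.0, 1.0], [-1.0, 0.5], [0.0, 0.0]]
weight = [1.0, -1.0]
signed, magnitude = grad_x_input_scores(embeddings, weight)
print(signed)     # [1.0, -1.5, 0.0]
print(magnitude)  # token 1 still gets a positive score
```

Summing over the embedding dimension keeps the sign (token 1 pushes the logit down), while the norm only reports how much each token matters. The tutorial's `.norm(dim=-1)` corresponds to the `magnitude` list; a signed variant would reduce with `.sum(dim=-1)` instead.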

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.4.0",
#    "torch>=2.1.0",
#    "transformers>=4.48.0",
#    "datasets>=3.0.0",
#    "scikit-learn",
#    ...
# ]
# ///
```

## Orchestrate the pipeline

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.4.0",
#    "torch>=2.1.0",
#    "transformers>=4.48.0",
#    "datasets>=3.0.0",
#    "accelerate>=0.34.0",
#    "scikit-learn",
#    "numpy",
# ]
# main = "pipeline"
# params = ""
# ///
import json
import logging
import os
import tempfile

import flyte
import flyte.io
import flyte.report

# {{docs-fragment env}}
import os

main_img = flyte.Image.from_uv_script(__file__, name="bert-fine-tuning-emotion", pre=True)

gpu_env = flyte.TaskEnvironment(
    name="bert-fine-tuning-emotion-gpu",
    image=main_img,
    resources=flyte.Resources(cpu=4, memory="16Gi", gpu=1),
    secrets=[flyte.Secret(key="huggingface-token", as_env_var="HF_TOKEN")],
)

cpu_env = flyte.TaskEnvironment(
    name="bert-fine-tuning-emotion-cpu",
    image=main_img,
    resources=flyte.Resources(cpu=2, memory="8Gi"),
    depends_on=[gpu_env],
)

HF_TOKEN = os.environ.get("HF_TOKEN")
# {{/docs-fragment env}}

from report_helpers import (
    make_attention_text,
    make_bar_chart,
    make_confidence_bars,
    make_confusion_matrix,
    make_line_chart,
    make_token_importance_text,
    pipeline_step_indicator,
    wrap_report,
)

logging.basicConfig(level=logging.WARNING, format="%(message)s", force=True)
log = logging.getLogger(__name__)
log.setLevel(logging.INFO)

EMOTION_LABELS = ["sadness", "joy", "love", "anger", "fear", "surprise"]
EMOTION_DATASET = "dair-ai/emotion"

# ------------------------------------------------------------------
# Task 1: Get data
# ------------------------------------------------------------------

@cpu_env.task(cache="auto")
async def get_data(
    max_train_samples: int = 10000,
    max_eval_samples: int = 2000,
) -> flyte.io.Dir:
    """Download the emotion dataset and save train/eval splits.

    The dair-ai/emotion dataset contains ~20k English Twitter messages labeled
    with one of 6 emotions: sadness, joy, love, anger, fear, surprise.
    """
    from datasets import DatasetDict, load_dataset

    log.info("Loading emotion dataset...")
    ds = load_dataset(EMOTION_DATASET)

    train_ds = ds["train"].shuffle(seed=42).select(range(min(max_train_samples, len(ds["train"]))))
    eval_ds = ds["test"].shuffle(seed=42).select(range(min(max_eval_samples, len(ds["test"]))))

    processed = DatasetDict({"train": train_ds, "eval": eval_ds})

    output_dir = os.path.join(tempfile.mkdtemp(), "dataset")
    processed.save_to_disk(output_dir)
    log.info(f"Dataset ready: {len(train_ds)} train, {len(eval_ds)} eval")

    return await flyte.io.Dir.from_local(output_dir)

# ------------------------------------------------------------------
# Task 2: Train
# ------------------------------------------------------------------

@gpu_env.task(report=True)
async def train(
    model_name: str,
    data_dir: flyte.io.Dir,
    epochs: int = 3,
    lr: float = 2e-5,
    batch_size: int = 16,
    warmup_steps: int = 100,
) -> flyte.io.Dir:
    """Fine-tune a BERT-style model for 6-class emotion classification."""
    import numpy as np
    import torch
    from datasets import load_from_disk
    from sklearn.metrics import accuracy_score, f1_score
    from transformers import (
        AutoModelForSequenceClassification,
        AutoTokenizer,
        Trainer,
        TrainerCallback,
        TrainingArguments,
    )

    log.info(f"Training: model={model_name}")

    id2label = {i: l for i, l in enumerate(EMOTION_LABELS)}
    label2id = {l: i for i, l in enumerate(EMOTION_LABELS)}

    await flyte.report.replace.aio(
        wrap_report(
            f"<h2>Loading Model...</h2>"
            f"<h3>{model_name}</h3>"
            f'<div class="card"><p>Preparing for emotion classification training...</p></div>'
        ),
        do_flush=True,
    )

    # -- Load data --
    data_path = await data_dir.download()
    dataset = load_from_disk(data_path)

    # -- Tokenize --
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN)

    def tokenize(examples):
        return tokenizer(examples["text"], truncation=True, max_length=128, padding="max_length")

    dataset = dataset.map(tokenize, batched=True, remove_columns=["text"])

    # -- Load model --
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        token=HF_TOKEN,
        num_labels=6,
        id2label=id2label,
        label2id=label2id,
    )

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    log.info(f"Parameters: {trainable_params:,} / {total_params:,}")

    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(0)
        gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
        log.info(f"GPU: {gpu_name} ({gpu_mem:.1f} GB)")

    # -- Metrics tracking for live report --
    training_log: list[dict] = []
    eval_log: list[dict] = []

    def _build_training_report(max_steps: int) -> str:
        stats_html = f"""
        <h2>Training in Progress...</h2>
        <h3>{model_name}</h3>
        <div class="stat-grid">
          <div class="stat"><div class="value">{len(dataset['train']):,}</div><div class="label">Train Samples</div></div>
          <div class="stat"><div class="value">{len(dataset['eval']):,}</div><div class="label">Eval Samples</div></div>
          <div class="stat"><div class="value">{epochs}</div><div class="label">Epochs</div></div>
          <div class="stat"><div class="value">{lr}</div><div class="label">Learning Rate</div></div>
          <div class="stat"><div class="value">{batch_size}</div><div class="label">Batch Size</div></div>
          <div class="stat"><div class="value">{trainable_params:,}</div><div class="label">Parameters</div></div>
        </div>
        """

        charts_html = ""

        if training_log:
            current = training_log[-1]
            progress_pct = current["step"] / max_steps * 100 if max_steps else 0
            loss_display = f"Loss: <span class=\"highlight\">{current['loss']:.4f}</span>" if current.get("loss") else ""
            charts_html += f"""
            <div class="card">
              <b>Step {current['step']}/{max_steps}</b>
              ({progress_pct:.0f}%) |
              Epoch {current['epoch']:.2f}/{epochs}
              {f' | {loss_display}' if loss_display else ''}
              <div style="background:#e9ecef;border-radius:4px;height:8px;margin-top:8px;">
                <div style="background:#0f3460;width:{progress_pct:.1f}%;height:100%;border-radius:4px;"></div>
              </div>
            </div>
            """

            loss_entries = [e for e in training_log if "loss" in e]
            if len(loss_entries) >= 2:
                loss_chart = make_line_chart(
                    data=loss_entries,
                    x_key="epoch",
                    y_keys=["loss"],
                    title="Training Loss",
                    x_label="Epoch",
                    y_label="Loss",
                    colors=["#5a7db5"],
                )
                charts_html += f'<div class="chart-container">{loss_chart}</div>'

        if eval_log:
            latest_eval = eval_log[-1]
            best_acc = max(e.get("accuracy", 0) for e in eval_log)
            best_f1 = max(e.get("f1", 0) for e in eval_log)
            charts_html += f"""
            <div class="stat-grid" style="margin-top:16px;">
              <div class="stat"><div class="value">{latest_eval.get('accuracy', 0):.1%}</div><div class="label">Eval Accuracy</div></div>
              <div class="stat"><div class="value">{latest_eval.get('f1', 0):.1%}</div><div class="label">Eval F1</div></div>
              <div class="stat"><div class="value">{best_acc:.1%}</div><div class="label">Best Accuracy</div></div>
              <div class="stat"><div class="value">{latest_eval.get('eval_loss', 0):.4f}</div><div class="label">Eval Loss</div></div>
            </div>
            """

            if len(eval_log) >= 2:
                eval_chart = make_line_chart(
                    data=eval_log,
                    x_key="epoch",
                    y_keys=["accuracy", "f1"],
                    title="Eval Metrics Over Training",
                    x_label="Epoch",
                    y_label="Score",
                    colors=["#0f3460", "#06d6a0"],
                    y_max_cap=1.05,
                    y_display_names={"accuracy": "Accuracy", "f1": "Weighted F1"},
                )
                charts_html += f'<div class="chart-container">{eval_chart}</div>'

                if any("eval_loss" in e for e in eval_log):
                    eval_loss_chart = make_line_chart(
                        data=[e for e in eval_log if "eval_loss" in e],
                        x_key="epoch",
                        y_keys=["eval_loss"],
                        title="Eval Loss",
                        x_label="Epoch",
                        y_label="Loss",
                        colors=["#e63946"],
                    )
                    charts_html += f'<div class="chart-container">{eval_loss_chart}</div>'

        return wrap_report(stats_html + charts_html)

    # -- Callbacks --
    class ReportCallback(TrainerCallback):
        def on_log(self, args, state, control, logs=None, **kwargs):
            if not logs:
                return
            entry = {
                "step": state.global_step,
                "epoch": round(logs.get("epoch", 0), 2),
            }
            if "loss" in logs:
                entry["loss"] = round(logs["loss"], 4)
            if "eval_accuracy" in logs:
                eval_log.append({
                    "epoch": entry["epoch"],
                    "accuracy": logs["eval_accuracy"],
                    "f1": logs.get("eval_f1", 0),
                    "eval_loss": logs.get("eval_loss", 0),
                })
            if "loss" in entry:
                training_log.append(entry)

            flyte.report.replace(
                _build_training_report(state.max_steps),
                do_flush=True,
            )

    # -- Compute metrics --
    def compute_metrics(eval_pred):
        logits, labels = eval_pred
        preds = np.argmax(logits, axis=-1)
        return {
            "accuracy": accuracy_score(labels, preds),
            "f1": f1_score(labels, preds, average="weighted"),
        }

    # -- Training --
    output_dir = os.path.join(tempfile.mkdtemp(), "checkpoints")
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size * 2,
        learning_rate=lr,
        logging_steps=10,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        bf16=use_bf16,
        fp16=not use_bf16 and torch.cuda.is_available(),
        warmup_steps=warmup_steps,
        report_to="none",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["eval"],
        processing_class=tokenizer,
        compute_metrics=compute_metrics,
        callbacks=[ReportCallback()],
    )

    log.info("Starting training...")
    await flyte.report.replace.aio(
        _build_training_report(0),
        do_flush=True,
    )

    trainer.train()
    log.info("Training complete.")

    # -- Save model --
    save_dir = os.path.join(tempfile.mkdtemp(), "finetuned_model")
    trainer.save_model(save_dir)
    tokenizer.save_pretrained(save_dir)
    log.info(f"Model saved to {save_dir}")

    # -- Final eval + report --
    metrics = trainer.evaluate()
    final_acc = metrics.get("eval_accuracy", 0)
    final_f1 = metrics.get("eval_f1", 0)

    final_charts = ""
    loss_entries = [e for e in training_log if "loss" in e]
    if len(loss_entries) >= 2:
        loss_chart = make_line_chart(
            data=loss_entries,
            x_key="epoch",
            y_keys=["loss"],
            title="Training Loss",
            x_label="Epoch",
            y_label="Loss",
            colors=["#5a7db5"],
        )
        final_charts += f'<div class="chart-container">{loss_chart}</div>'

    if len(eval_log) >= 2:
        eval_chart = make_line_chart(
            data=eval_log,
            x_key="epoch",
            y_keys=["accuracy", "f1"],
            title="Eval Metrics Over Training",
            x_label="Epoch",
            y_label="Score",
            colors=["#0f3460", "#06d6a0"],
            y_max_cap=1.05,
            y_display_names={"accuracy": "Accuracy", "f1": "Weighted F1"},
        )
        final_charts += f'<div class="chart-container">{eval_chart}</div>'

    await flyte.report.replace.aio(
        wrap_report(
            f"<h2>Training Complete</h2>"
            f"<h3>{model_name}</h3>"
            f'<div class="stat-grid">'
            f'  <div class="stat"><div class="value">{final_acc:.1%}</div><div class="label">Accuracy</div></div>'
            f'  <div class="stat"><div class="value">{final_f1:.1%}</div><div class="label">Weighted F1</div></div>'
            f'  <div class="stat"><div class="value">{epochs}</div><div class="label">Epochs</div></div>'
            f'  <div class="stat"><div class="value">{trainable_params:,}</div><div class="label">Parameters</div></div>'
            f'</div>'
            f"{final_charts}"
        ),
        do_flush=True,
    )

    return await flyte.io.Dir.from_local(save_dir)

# ------------------------------------------------------------------
# Task 3: Evaluate
# ------------------------------------------------------------------

@gpu_env.task(report=True)
async def evaluate(
    model_name: str,
    finetuned_dir: flyte.io.Dir,
    data_dir: flyte.io.Dir,
    num_examples: int = 200,
) -> str:
    """Compare base model (random head) vs fine-tuned on emotion classification.

    Produces confusion matrix, per-class precision/recall/F1, and overall metrics.
    """
    import numpy as np
    import torch
    from datasets import load_from_disk
    from sklearn.metrics import (
        accuracy_score,
        classification_report,
        confusion_matrix as sk_confusion_matrix,
        f1_score,
    )
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    log.info("Starting evaluation...")
    await flyte.report.replace.aio(
        wrap_report("<h2>Evaluation</h2><p>Loading models...</p>"),
        do_flush=True,
    )

    # -- Load eval data --
    data_path = await data_dir.download()
    dataset = load_from_disk(data_path)
    eval_ds = dataset["eval"].select(range(min(num_examples, len(dataset["eval"]))))
    # Materialize the columns as plain lists so they can be sliced and indexed below
    texts = list(eval_ds["text"])
    labels = list(eval_ds["label"])

    def predict_batch(model, tokenizer, texts, batch_size=32):
        preds = []
        probs_all = []
        for i in range(0, len(texts), batch_size):
            batch = texts[i : i + batch_size]
            inputs = tokenizer(batch, truncation=True, max_length=128, padding=True, return_tensors="pt")
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = model(**inputs)
            batch_probs = torch.softmax(outputs.logits, dim=-1).cpu()
            batch_preds = torch.argmax(batch_probs, dim=-1).tolist()
            preds.extend(batch_preds)
            probs_all.extend(batch_probs.tolist())
        return preds, probs_all

    # -- Base model --
    log.info(f"Loading base model: {model_name}")
    await flyte.report.replace.aio(
        wrap_report("<h2>Evaluation</h2><p>Running base model (random classifier head)...</p>"),
        do_flush=True,
    )

    base_tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN)
    base_model = AutoModelForSequenceClassification.from_pretrained(
        model_name, token=HF_TOKEN, num_labels=6,
    )
    base_model.eval()
    if torch.cuda.is_available():
        base_model = base_model.cuda()

    base_preds, base_probs = predict_batch(base_model, base_tokenizer, texts)
    del base_model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # -- Fine-tuned model --
    log.info("Loading fine-tuned model...")
    await flyte.report.replace.aio(
        wrap_report("<h2>Evaluation</h2><p>Running fine-tuned model...</p>"),
        do_flush=True,
    )

    ft_path = await finetuned_dir.download()
    ft_tokenizer = AutoTokenizer.from_pretrained(ft_path)
    ft_model = AutoModelForSequenceClassification.from_pretrained(ft_path)
    ft_model.eval()
    if torch.cuda.is_available():
        ft_model = ft_model.cuda()

    ft_preds, ft_probs = predict_batch(ft_model, ft_tokenizer, texts)
    del ft_model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # -- Compute metrics --
    base_acc = accuracy_score(labels, base_preds) * 100
    base_f1 = f1_score(labels, base_preds, average="weighted") * 100
    ft_acc = accuracy_score(labels, ft_preds) * 100
    ft_f1 = f1_score(labels, ft_preds, average="weighted") * 100

    log.info(f"Base:      Accuracy={base_acc:.1f}%, F1={base_f1:.1f}%")
    log.info(f"Fine-tuned: Accuracy={ft_acc:.1f}%, F1={ft_f1:.1f}%")

    # -- Confusion matrix --
    ft_cm = sk_confusion_matrix(labels, ft_preds, labels=list(range(6)))
    cm_list = ft_cm.tolist()
    cm_svg = make_confusion_matrix(cm_list, EMOTION_LABELS, title="Fine-tuned Model — Confusion Matrix")

    # -- Per-class metrics --
    report_dict = classification_report(
        labels, ft_preds, labels=list(range(6)), target_names=EMOTION_LABELS,
        output_dict=True, zero_division=0,
    )
    per_class_html = "<table><tr><th>Emotion</th><th>Precision</th><th>Recall</th><th>F1</th><th>Support</th></tr>"
    for label_name in EMOTION_LABELS:
        if label_name in report_dict:
            m = report_dict[label_name]
            per_class_html += (
                f"<tr><td><b>{label_name}</b></td>"
                f"<td>{m['precision']:.1%}</td>"
                f"<td>{m['recall']:.1%}</td>"
                f"<td>{m['f1-score']:.1%}</td>"
                f"<td>{int(m['support'])}</td></tr>"
            )
    per_class_html += "</table>"

    # -- Bar chart: base vs fine-tuned --
    per_class_base_acc = []
    per_class_ft_acc = []
    for cls_idx in range(6):
        cls_mask = [i for i, l in enumerate(labels) if l == cls_idx]
        if cls_mask:
            base_cls_acc = sum(1 for i in cls_mask if base_preds[i] == cls_idx) / len(cls_mask) * 100
            ft_cls_acc = sum(1 for i in cls_mask if ft_preds[i] == cls_idx) / len(cls_mask) * 100
        else:
            base_cls_acc = 0
            ft_cls_acc = 0
        per_class_base_acc.append(base_cls_acc)
        per_class_ft_acc.append(ft_cls_acc)

    bar_chart = make_bar_chart(
        labels=EMOTION_LABELS,
        series={"Base": per_class_base_acc, "Fine-tuned": per_class_ft_acc},
        title="Per-Class Accuracy — Base vs Fine-tuned",
        colors=["#adb5bd", "#0f3460"],
        y_max_cap=105.0,
    )

    # -- Example predictions --
    improvement = ft_acc - base_acc
    imp_badge = "badge-success" if improvement > 0 else "badge-danger" if improvement < 0 else "badge-info"

    examples_html = ""
    for i in range(min(10, len(texts))):
        true_label = EMOTION_LABELS[labels[i]]
        ft_label = EMOTION_LABELS[ft_preds[i]]
        base_label = EMOTION_LABELS[base_preds[i]]
        ft_correct = ft_preds[i] == labels[i]
        base_correct = base_preds[i] == labels[i]
        text_preview = texts[i][:200]

        ft_badge = "badge-success" if ft_correct else "badge-danger"
        base_badge = "badge-success" if base_correct else "badge-danger"

        examples_html += f"""
<div class="card">
  <p style="font-size:0.95em;">"{text_preview}"</p>
  <p>True: <b>{true_label}</b> |
  Base: <span class="badge {base_badge}">{base_label}</span> |
  Fine-tuned: <span class="badge {ft_badge}">{ft_label}</span></p>
</div>"""

    await flyte.report.replace.aio(
        wrap_report(
            f"<h2>Evaluation Results — Emotion Classification</h2>"
            f'<div class="stat-grid">'
            f'  <div class="stat"><div class="value">{base_acc:.1f}%</div><div class="label">Base Accuracy</div></div>'
            f'  <div class="stat"><div class="value">{ft_acc:.1f}%</div><div class="label">Fine-tuned Accuracy</div></div>'
            f'  <div class="stat"><div class="value"><span class="badge {imp_badge}">{improvement:+.1f}pp</span></div><div class="label">Improvement</div></div>'
            f'  <div class="stat"><div class="value">{ft_f1:.1f}%</div><div class="label">Fine-tuned F1</div></div>'
            f'</div>'
            f'<div class="chart-container">{bar_chart}</div>'
            f'<div class="chart-container">{cm_svg}</div>'
            f"<h3>Per-Class Metrics (Fine-tuned)</h3>"
            f"{per_class_html}"
            f"<h3>Example Predictions</h3>"
            f"{examples_html}"
        ),
        do_flush=True,
    )

    return json.dumps({
        "base_accuracy": round(base_acc, 1),
        "base_f1": round(base_f1, 1),
        "finetuned_accuracy": round(ft_acc, 1),
        "finetuned_f1": round(ft_f1, 1),
        "improvement": round(improvement, 1),
        "num_examples": len(texts),
        "confusion_matrix": cm_list,
        "per_class": {k: report_dict[k] for k in EMOTION_LABELS if k in report_dict},
    })

# ------------------------------------------------------------------
# Task 4: Explore inference
# ------------------------------------------------------------------

@gpu_env.task(report=True)
async def explore_inference(
    finetuned_dir: flyte.io.Dir,
    data_dir: flyte.io.Dir,
    num_examples: int = 8,
) -> str:
    """Deep-dive into model behavior with attention and token importance.

    For a set of examples, this task produces:
    1. Predictions with full confidence distribution across all 6 emotions
    2. Attention heatmaps — which tokens the model focuses on for classification
       (CLS token attention from the last layer, averaged across heads)
    3. Token importance via gradient-based attribution — which tokens most
       influence the predicted class (gradient x embedding norm)
    4. Misclassification analysis — confident wrong predictions with explanations
    """
    import numpy as np
    import torch
    from datasets import load_from_disk
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    log.info("Starting explore_inference...")
    await flyte.report.replace.aio(
        wrap_report(
            "<h2>Explore Inference</h2>"
            "<p>Loading model for attention and attribution analysis...</p>"
        ),
        do_flush=True,
    )

    # -- Load model (with eager attention for weight extraction) --
    ft_path = await finetuned_dir.download()
    tokenizer = AutoTokenizer.from_pretrained(ft_path)

    # Need eager attention to extract attention weights (flash attention doesn't return them)
    model = AutoModelForSequenceClassification.from_pretrained(
        ft_path,
        output_attentions=True,
        attn_implementation="eager",
    )
    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # -- Load eval data --
    data_path = await data_dir.download()
    dataset = load_from_disk(data_path)
    eval_ds = dataset["eval"]

    # Pick a diverse set of examples — try to get some from each class
    examples_per_class = max(1, num_examples // 6)
    selected_indices = []
    for cls_idx in range(6):
        cls_indices = [i for i in range(len(eval_ds)) if eval_ds[i]["label"] == cls_idx]
        selected_indices.extend(cls_indices[:examples_per_class])
    # Fill any remaining slots with the first examples not already selected (not random)
    remaining = num_examples - len(selected_indices)
    if remaining > 0:
        other_indices = [i for i in range(len(eval_ds)) if i not in selected_indices]
        selected_indices.extend(other_indices[:remaining])
    selected_indices = selected_indices[:num_examples]

    # -- Analyze each example --
    analyses = []
    for idx_num, ds_idx in enumerate(selected_indices):
        text = eval_ds[ds_idx]["text"]
        true_label = eval_ds[ds_idx]["label"]

        await flyte.report.replace.aio(
            wrap_report(
                f"<h2>Explore Inference</h2>"
                f"<p>Analyzing example {idx_num + 1}/{len(selected_indices)}...</p>"
                f'<div style="background:#e9ecef;border-radius:4px;height:8px;margin-top:8px;">'
                f'<div style="background:#0f3460;width:{(idx_num + 1) / len(selected_indices) * 100:.1f}%;height:100%;border-radius:4px;"></div>'
                f'</div>'
            ),
            do_flush=True,
        )

        # Tokenize
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        token_ids = inputs["input_ids"][0]
        tokens = tokenizer.convert_ids_to_tokens(token_ids)

        # Forward pass with attention
        with torch.no_grad():
            outputs = model(**inputs)

        logits = outputs.logits[0]
        probs = torch.softmax(logits, dim=-1).cpu().tolist()
        pred_idx = int(torch.argmax(logits).item())

        # -- Attention: CLS token attention from last layer --
        # attentions shape: (num_layers, batch, num_heads, seq_len, seq_len)
        last_layer_attention = outputs.attentions[-1][0]  # (num_heads, seq_len, seq_len)
        # Average across heads, take CLS row (index 0)
        cls_attention = last_layer_attention.mean(dim=0)[0].cpu().numpy()  # (seq_len,)

        # Remove [CLS] and [SEP] and padding from visualization
        real_token_mask = []
        clean_tokens = []
        clean_attention = []
        for i, tok in enumerate(tokens):
            if tok in ("[CLS]", "[SEP]", "<s>", "</s>", "[PAD]", "<pad>"):
                continue
            if tok == tokenizer.pad_token:
                continue
            clean_tokens.append(tok)
            clean_attention.append(float(cls_attention[i]))
            real_token_mask.append(i)

        # -- Token importance via gradient attribution --
        # Re-run with gradients enabled on embeddings
        embedding_layer = None
        for name, module in model.named_modules():
            if isinstance(module, torch.nn.Embedding) and "word" in name.lower():
                embedding_layer = module
                break
        if embedding_layer is None:
            # Fallback: find the first large embedding
            for name, module in model.named_modules():
                if isinstance(module, torch.nn.Embedding) and module.weight.shape[0] > 1000:
                    embedding_layer = module
                    break

        importance_scores = [0.0] * len(clean_tokens)
        if embedding_layer is not None:
            inputs_grad = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)
            inputs_grad = {k: v.to(device) for k, v in inputs_grad.items()}

            embeddings = embedding_layer(inputs_grad["input_ids"])
            embeddings.retain_grad()

            # Run model with embeddings instead of input_ids
            # We need to hook into the model to replace the embedding output
            embedding_output = [None]

            def hook_fn(module, input, output):
                embedding_output[0] = output
                return embeddings.requires_grad_(True)

            handle = embedding_layer.register_forward_hook(hook_fn)

            outputs_grad = model(**inputs_grad)
            handle.remove()

            # Gradient of predicted class w.r.t. embeddings
            pred_score = outputs_grad.logits[0, pred_idx]
            pred_score.backward()

            if embeddings.grad is not None:
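                # Token importance = L2 norm of (gradient * embedding) per token.
                # The norm is always >= 0, so it measures how strongly a token influences
                # the predicted logit, not whether it pushes toward or away from it.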
                token_importance = (embeddings.grad[0] * embeddings[0]).norm(dim=-1).detach().cpu().numpy()
                for clean_idx, orig_idx in enumerate(real_token_mask):
                    if orig_idx < len(token_importance):
                        importance_scores[clean_idx] = float(token_importance[orig_idx])

            model.zero_grad()

        analyses.append({
            "text": text,
            "true_label": true_label,
            "pred_idx": pred_idx,
            "probs": probs,
            "tokens": clean_tokens,
            "attention": clean_attention,
            "importance": importance_scores,
            "correct": pred_idx == true_label,
        })

    # -- Build report --
    log.info("Building explore_inference report...")

    # Overall summary
    correct = sum(1 for a in analyses if a["correct"])
    total = len(analyses)

    # Separate correct vs wrong
    correct_analyses = [a for a in analyses if a["correct"]]
    wrong_analyses = [a for a in analyses if not a["correct"]]

    # -- Build example cards --
    examples_html = ""
    for a in analyses:
        true_name = EMOTION_LABELS[a["true_label"]]
        pred_name = EMOTION_LABELS[a["pred_idx"]]
        status_badge = "badge-success" if a["correct"] else "badge-danger"
        status_text = "Correct" if a["correct"] else "Wrong"

        # Confidence bars
        conf_bars = make_confidence_bars(
            labels=EMOTION_LABELS,
            probabilities=a["probs"],
            predicted_idx=a["pred_idx"],
            true_idx=a["true_label"],
        )

        # Attention heatmap
        attention_viz = make_attention_text(
            tokens=a["tokens"],
            weights=a["attention"],
            title="Attention (what the model looks at for its prediction — darker = more attention)",
        )

        # Token importance
        importance_viz = make_token_importance_text(
            tokens=a["tokens"],
            importance=a["importance"],
            title="Token importance (gradient attribution — green = supports prediction, red = opposes)",
        )

        text_preview = a["text"][:300]
        examples_html += f"""
<div class="card">
  <p style="font-size:1em;"><b>"{text_preview}"</b></p>
  <p>True: <b>{true_name}</b> | Predicted: <span class="badge {status_badge}">{pred_name} ({status_text})</span>
     | Confidence: <b>{a['probs'][a['pred_idx']]:.1%}</b></p>
  <div style="margin:12px 0;">{conf_bars}</div>
  <div style="margin:12px 0;">{attention_viz}</div>
  <div style="margin:12px 0;">{importance_viz}</div>
</div>"""

    # -- Misclassification spotlight --
    misclass_html = ""
    if wrong_analyses:
        # Sort by confidence (most confident wrong first)
        wrong_sorted = sorted(wrong_analyses, key=lambda a: a["probs"][a["pred_idx"]], reverse=True)

        misclass_html = "<h3>Misclassification Spotlight</h3>"
        misclass_html += '<div class="note">These are the model\'s most confident wrong predictions — cases where the model is sure but incorrect. These reveal the model\'s blind spots.</div>'

        for a in wrong_sorted[:3]:
            true_name = EMOTION_LABELS[a["true_label"]]
            pred_name = EMOTION_LABELS[a["pred_idx"]]
            conf = a["probs"][a["pred_idx"]]
            true_conf = a["probs"][a["true_label"]]

            misclass_html += f"""
<div class="card" style="border-left:4px solid #e63946;">
  <p><b>"{a['text'][:200]}"</b></p>
  <p>Predicted <span class="badge badge-danger">{pred_name}</span> ({conf:.1%})
     but true label is <span class="badge badge-info">{true_name}</span> ({true_conf:.1%})</p>
  <p style="font-size:0.85em;color:#6c757d;">
     The model assigned {conf:.1%} confidence to {pred_name} vs {true_conf:.1%} to {true_name}.
     {"The model was very sure here — this is a genuine blind spot." if conf > 0.7 else "The model was uncertain — the true class was a close second."}
  </p>
</div>"""

    await flyte.report.replace.aio(
        wrap_report(
            f"<h2>Explore Inference — Attention &amp; Attribution</h2>"
            f'<div class="stat-grid">'
            f'  <div class="stat"><div class="value">{correct}/{total}</div><div class="label">Correct</div></div>'
            f'  <div class="stat"><div class="value">{correct/total:.0%}</div><div class="label">Accuracy (sample)</div></div>'
            f'  <div class="stat"><div class="value">{len(wrong_analyses)}</div><div class="label">Errors to Analyze</div></div>'
            f'</div>'
            f'<div class="note">'
            f'<b>How to read the visualizations below:</b><br/>'
            f'<b>Attention heatmap:</b> Shows which tokens the [CLS] token attends to in the final layer '
            f'(averaged across all attention heads). Darker = more attention. This reveals what the model "looks at" when making its classification decision.<br/>'
            f'<b>Token importance:</b> Gradient-based attribution showing which tokens most influence the prediction. '
            f'Green = supports the prediction, Red = opposes it. Computed as gradient &times; embedding norm.'
            f'</div>'
            f"<h3>Example Analysis</h3>"
            f"{examples_html}"
            f"{misclass_html}"
        ),
        do_flush=True,
    )

    return json.dumps({
        "num_examples": total,
        "correct": correct,
        "accuracy": round(correct / total * 100, 1),
        "num_misclassifications": len(wrong_analyses),
        "analyses": [
            {
                "text": a["text"][:200],
                "true_label": EMOTION_LABELS[a["true_label"]],
                "predicted": EMOTION_LABELS[a["pred_idx"]],
                "confidence": round(a["probs"][a["pred_idx"]], 3),
                "correct": a["correct"],
            }
            for a in analyses
        ],
    })

# ------------------------------------------------------------------
# Pipeline
# ------------------------------------------------------------------

# {{docs-fragment pipeline}}
@cpu_env.task(report=True)
async def pipeline(
    model_name: str = "answerdotai/ModernBERT-base",
    epochs: int = 3,
    lr: float = 2e-5,
    batch_size: int = 16,
    warmup_steps: int = 100,
    max_train_samples: int = 10000,
    max_eval_samples: int = 2000,
    num_eval_examples: int = 200,
    num_explore_examples: int = 12,
) -> flyte.io.Dir:
    """
    ModernBERT emotion classification pipeline.

    Returns the fine-tuned model directory (used by serve.py for deployment).

    1. Download emotion dataset (6 classes from Twitter text)
    2. Fine-tune ModernBERT for sequence classification
    3. Evaluate: base vs fine-tuned with confusion matrix
    4. Explore inference: attention heatmaps + token importance

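    Args:
        model_name: Hugging Face encoder model to fine-tune.
        epochs: Number of fine-tuning epochs.
        lr: Learning rate for fine-tuning.
        batch_size: Training batch size.
        warmup_steps: Number of learning-rate warmup steps.
        max_train_samples: Maximum number of training examples to use.
        max_eval_samples: Maximum number of evaluation examples to use.
        num_eval_examples: Number of eval examples for the base vs fine-tuned comparison.
        num_explore_examples: Number of examples for attention/attribution analysis.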
    """
    log.info(f"Pipeline: {model_name} | emotion classification")
    steps = ["Get Data", "Train", "Evaluate", "Explore Inference"]

    await flyte.report.replace.aio(
        wrap_report(
            f"<h2>Emotion Classification Pipeline</h2>"
            f"<h3>{model_name}</h3>"
            f"{pipeline_step_indicator(0, steps)}"
            f'<div class="card"><p>Downloading emotion dataset...</p></div>'
        ),
        do_flush=True,
    )

    # Step 1: Get data
    data_dir = await get_data(max_train_samples, max_eval_samples)

    # Step 2: Train
    await flyte.report.replace.aio(
        wrap_report(
            f"<h2>Emotion Classification Pipeline</h2>"
            f"<h3>{model_name}</h3>"
            f"{pipeline_step_indicator(1, steps)}"
            f'<div class="card"><p>Fine-tuning for emotion classification...</p></div>'
        ),
        do_flush=True,
    )

    finetuned_dir = await train(model_name, data_dir, epochs, lr, batch_size, warmup_steps)

    # Step 3: Evaluate
    await flyte.report.replace.aio(
        wrap_report(
            f"<h2>Emotion Classification Pipeline</h2>"
            f"<h3>{model_name}</h3>"
            f"{pipeline_step_indicator(2, steps)}"
            f'<div class="card"><p>Evaluating base vs fine-tuned model...</p></div>'
        ),
        do_flush=True,
    )

    eval_result = await evaluate(model_name, finetuned_dir, data_dir, num_eval_examples)
    eval_metrics = json.loads(eval_result)

    # Step 4: Explore inference
    await flyte.report.replace.aio(
        wrap_report(
            f"<h2>Emotion Classification Pipeline</h2>"
            f"<h3>{model_name}</h3>"
            f"{pipeline_step_indicator(3, steps)}"
            f'<div class="card"><p>Analyzing attention patterns and token importance...</p></div>'
        ),
        do_flush=True,
    )

    explore_result = await explore_inference(finetuned_dir, data_dir, num_explore_examples)

    # -- Final report --
    improvement = eval_metrics["improvement"]
    imp_badge = "badge-success" if improvement > 0 else "badge-danger" if improvement < 0 else "badge-info"

    await flyte.report.replace.aio(
        wrap_report(
            f"<h2>Emotion Classification Pipeline Complete</h2>"
            f"<h3>{model_name}</h3>"
            f"{pipeline_step_indicator(4, steps)}"
            f'<div class="stat-grid">'
            f'  <div class="stat"><div class="value">{eval_metrics["base_accuracy"]}%</div><div class="label">Base Accuracy</div></div>'
            f'  <div class="stat"><div class="value">{eval_metrics["finetuned_accuracy"]}%</div><div class="label">Fine-tuned Accuracy</div></div>'
            f'  <div class="stat"><div class="value"><span class="badge {imp_badge}">{improvement:+.1f}pp</span></div><div class="label">Improvement</div></div>'
            f'  <div class="stat"><div class="value">{eval_metrics["finetuned_f1"]}%</div><div class="label">Weighted F1</div></div>'
            f'</div>'
        ),
        do_flush=True,
    )

    log.info(f"Pipeline complete. Accuracy improvement: {improvement:+.1f}pp")
    return finetuned_dir

# {{/docs-fragment pipeline}}

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(pipeline)
    print(run.url)
    run.wait()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/bert_fine_tuning_emotion/bert_fine_tuning_emotion.py*

## Run the workflow

From the [example directory](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/bert_fine_tuning_emotion):

```
cd v2/tutorials/bert_fine_tuning_emotion
uv run --script bert_fine_tuning_emotion.py
```

Quick smoke test with a small sample:

```
flyte run bert_fine_tuning_emotion.py pipeline --max_train_samples 200 --max_eval_samples 50 --epochs 1
```

Open the **evaluate** and **explore_inference** task reports for confusion matrices and attention visualizations.

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/model-training/hpo ===

# Hyperparameter optimization

> [!NOTE]
> The code for this tutorial is available on [GitHub](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/ml/optimizer.py).

Hyperparameter optimization (HPO) is a critical step in the machine learning (ML) lifecycle. Hyperparameters are the knobs and dials of a model: values such as learning rates, tree depths, or dropout rates that significantly impact performance but are not learned during training. Instead, we must select them manually or optimize them through guided search.

Model developers often enjoy the flexibility of choosing from a wide variety of model types, whether gradient boosted machines (GBMs), generalized linear models (GLMs), deep learning architectures, or dozens of others. A common challenge across all these options is the need to systematically explore model performance across hyperparameter configurations tailored to the specific dataset and task.

Thankfully, this exploration can be automated. Frameworks like [Optuna](https://optuna.org/), [Hyperopt](https://hyperopt.github.io/hyperopt/), and [Ray Tune](https://docs.ray.io/en/latest/tune/index.html) use advanced sampling algorithms to efficiently search the hyperparameter space and identify optimal configurations. HPO may be executed in two distinct ways:

- **Serial HPO** runs one trial at a time, which is easy to set up but can be painfully slow.
- **Parallel HPO** distributes trials across multiple processes. It typically follows a pattern with two parameters: **_N_**, the total number of trials to run, and **_C_**, the maximum number of trials that can run concurrently. Trials are executed asynchronously, and new ones are scheduled based on the results and status of completed or in-progress ones.
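The N/C pattern can be illustrated with the standard library alone. In this toy sketch, `run_trials` and `fake_trial` are hypothetical names, `asyncio.sleep` stands in for training a model, and an `asyncio.Semaphore` caps concurrency at C while all N trials are submitted at once. A real parallel optimizer would additionally ask its sampler for new parameters each time a slot frees up.

```python
import asyncio


async def run_trials(n_trials: int, concurrency: int) -> tuple[list[float], int]:
    """Run n_trials toy trials with at most `concurrency` of them in flight at once."""
    semaphore = asyncio.Semaphore(concurrency)
    in_flight = 0
    peak = 0

    async def fake_trial(number: int) -> float:
        nonlocal in_flight, peak
        async with semaphore:  # waits here while `concurrency` trials are already running
            in_flight += 1
            peak = max(peak, in_flight)
            await asyncio.sleep(0.01)  # stand-in for training and scoring a model
            in_flight -= 1
            return 1.0 / (1 + number)  # toy score

    # All N trials are submitted up front; the semaphore caps how many run at once.
    scores = await asyncio.gather(*(fake_trial(i) for i in range(n_trials)))
    return list(scores), peak


scores, peak = asyncio.run(run_trials(n_trials=20, concurrency=5))
print(f"{len(scores)} trials finished, peak concurrency {peak}")
```

`asyncio.gather` returns results in submission order, so the scores line up with trial numbers even though trials finish in whatever order the scheduler allows.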

However, parallel HPO introduces a new complexity: the need for a centralized state that tracks:

- All past trials (successes and failures)
- All ongoing trials

This state is essential so that the optimization algorithm can make informed decisions about which hyperparameters to try next.
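To make the role of this state concrete, here is a deliberately simple sketch: a hypothetical `TrialLedger` that records every trial's state and hands out points from a small grid. Because `ask` consults running trials as well as finished ones, two concurrent workers never receive the same configuration. The `ask`/`tell` method names echo Optuna's interface, but the class itself is illustrative only; real samplers such as TPE also use the recorded scores, not just which points are taken.

```python
from collections import Counter
from dataclasses import dataclass, field
from itertools import product
from typing import Optional

Point = tuple[int, int]  # (n_estimators, max_depth) in this toy grid


@dataclass
class TrialLedger:
    """Hypothetical central record of every trial, finished or still running."""

    grid: list[Point]
    states: dict[int, str] = field(default_factory=dict)  # trial number -> state
    params: dict[int, Point] = field(default_factory=dict)  # trial number -> grid point
    scores: dict[int, float] = field(default_factory=dict)  # completed trials only

    def ask(self) -> Optional[tuple[int, Point]]:
        """Start a new trial on a grid point that no earlier trial has claimed."""
        taken = set(self.params.values())  # includes RUNNING trials, not just finished ones
        for point in self.grid:
            if point not in taken:
                number = len(self.states)
                self.states[number] = "running"
                self.params[number] = point
                return number, point
        return None  # every grid point has been claimed

    def tell(self, number: int, score: Optional[float]) -> None:
        """Record a finished trial; a score of None marks it as failed."""
        if score is None:
            self.states[number] = "fail"
        else:
            self.states[number] = "complete"
            self.scores[number] = score

    def counts(self) -> Counter:
        return Counter(self.states.values())


ledger = TrialLedger(grid=list(product([50, 100], [2, 5])))
first = ledger.ask()   # trial 0 gets (50, 2)
second = ledger.ask()  # trial 1 gets (50, 5): (50, 2) is skipped while trial 0 runs
ledger.tell(first[0], 0.93)  # hypothetical score
ledger.tell(second[0], None)  # hypothetical failure
print(ledger.counts())
```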

## A better way to run HPO

This is where Flyte shines.

- There's no need to manage a separate centralized database for state tracking, as every objective run is **cached**, **recorded**, and **recoverable** via Flyte's execution engine.
- The entire HPO process is observable in the UI with full lineage and metadata for each trial.
- Each objective is seeded for reproducibility, enabling deterministic trial results.
- If the main optimization task crashes or is terminated, **Flyte can resume from the last successful or failed trial, making the experiment highly fault-tolerant**.
- Trial functions can be strongly typed, enabling rich, flexible hyperparameter spaces while maintaining strict type safety across trials.

In this example, we combine Flyte with Optuna to optimize a `RandomForestClassifier` on the Iris dataset. Each trial runs in an isolated task, and the optimization process is orchestrated asynchronously, with Flyte handling the underlying scheduling, retries, and caching.

## Declare dependencies

We start by declaring the script's Python version (3.13) and runtime dependencies as inline script metadata (PEP 723), which `uv` reads to build the script's environment.

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "optuna>=4.0.0,<5.0.0",
#    "flyte>=2.0.0b0",
#    "scikit-learn==1.7.0",
# ]
# ///
```

With the environment defined, we import the standard library and third-party modules needed for both the ML task and distributed execution.

```
import asyncio
import typing
from collections import Counter
from typing import Optional, Union
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py*

These standard library imports are essential for asynchronous execution (`asyncio`), type annotations (`typing`, `Optional`, `Union`), and aggregating trial state counts (`Counter`).

```
import optuna
from optuna import Trial
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
from sklearn.utils import shuffle
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py*

We use Optuna for hyperparameter optimization and several scikit-learn utilities to load the data (`load_iris`), define the model (`RandomForestClassifier`), evaluate it (`cross_val_score`), and reorder the samples with a fixed seed (`shuffle`).

```
import flyte
import flyte.errors
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py*

Flyte is our orchestration framework. We use it to define tasks, manage resources, and recover from execution errors.

## Define the task environment

We define a Flyte task environment called `driver`, which encapsulates metadata, compute resources, the container image context needed for remote execution, and caching behavior.

```
driver = flyte.TaskEnvironment(
    name="driver",
    resources=flyte.Resources(cpu=1, memory="250Mi"),
    image=flyte.Image.from_uv_script(__file__, name="optimizer"),
    cache="auto",
)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py*

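This environment gives every task defined with `@driver.task` (here, both `objective` and `optimize`) 1 CPU and 250Mi of memory, builds the container image from the current script's inline dependency metadata (`__file__`), and sets `cache="auto"` so that calling a task again with the same inputs and unchanged code reuses its cached result.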

<p>
  You can configure the Flyte task environment to reuse containers across multiple executions by setting the
  <code>reusable</code> field to
  <code>flyte.ReusePolicy(replicas=..., idle_ttl=...)</code>. This is especially useful when the final objective
  computations are short-lived, as it avoids unnecessary container spin-up costs. Learn more about reusable containers
  <a href="../../../user-guide/reusable-containers/">here</a>.
</p>

## Define the optimizer

Next, we define an `Optimizer` class that handles parallel execution of Optuna trials using async coroutines. This class abstracts the full optimization loop and supports concurrent trial execution with live logging.

```
class Optimizer:
    def __init__(
        self,
        objective: typing.Callable,
        n_trials: int,
        concurrency: int = 1,
        delay: float = 0.1,
        study: Optional[optuna.Study] = None,
        log_delay: float = 0.1,
    ):
        self.n_trials: int = n_trials
        self.concurrency: int = concurrency
        self.objective: typing.Callable = objective
        self.delay: float = delay
        self.log_delay = log_delay

        self.study = study if study else optuna.create_study()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py*

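We pass the `objective` function, the number of trials to run (`n_trials`), and the maximum number of parallel trials (`concurrency`). `delay` is how long each trial waits after finishing before it releases its concurrency slot, and `log_delay` is the interval between progress log lines (a falsy value such as `0` disables logging). If no Optuna study is provided, a new one is created with `optuna.create_study()`, which minimizes by default; the `optimize` task below passes its own study with `direction="maximize"` because the objective returns accuracy.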

```
    async def log(self):
        while True:
            await asyncio.sleep(self.log_delay)

            counter = Counter()

            for trial in self.study.trials:
                counter[trial.state.name.lower()] += 1

            counts = dict(counter, queued=self.n_trials - len(self))

            # print items in dictionary in a readable format
            formatted = [f"{name}: {count}" for name, count in counts.items()]
            print(f"{'    '.join(formatted)}")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py*

This method periodically prints how many trials are in each Optuna state (for example, running, complete, fail), plus how many are still queued (`n_trials` minus the trials created so far). `__call__` starts it as a background task when `log_delay` is set, so users can follow optimization progress live.

![Optuna logging](https://raw.githubusercontent.com/unionai/unionai-docs-static/main/images/tutorials/hpo/logging.png)
_Logs are streamed live as the execution progresses._

```
    async def spawn(self, semaphore: asyncio.Semaphore):
        async with semaphore:
            trial: Trial = self.study.ask()

            try:
                print("Starting trial", trial.number)

                params = {
                    "n_estimators": trial.suggest_int("n_estimators", 10, 200),
                    "max_depth": trial.suggest_int("max_depth", 2, 20),
                    "min_samples_split": trial.suggest_float(
                        "min_samples_split", 0.1, 1.0
                    ),
                }

                output = await self.objective(params)

                self.study.tell(trial, output, state=optuna.trial.TrialState.COMPLETE)
            except flyte.errors.RuntimeUserError as e:
                print(f"Trial {trial.number} failed: {e}")

                self.study.tell(trial, state=optuna.trial.TrialState.FAIL)

            await asyncio.sleep(self.delay)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py*

Each call to `spawn` runs a single Optuna trial. The `semaphore` ensures that only a fixed number of concurrent trials are active at once, respecting the `concurrency` parameter. We first ask Optuna for a new trial and generate a parameter dictionary by querying the trial object for suggested hyperparameters. The trial is then evaluated by the objective function. If successful, we mark it as `COMPLETE`. If the trial fails due to a `RuntimeUserError` from Flyte, we log and record the failure in the Optuna study.
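The try/except in `spawn` means a failing trial is recorded rather than aborting the whole `asyncio.gather`. The stdlib sketch below reproduces that control flow under stated assumptions: `TrialFailed` is a hypothetical stand-in for `flyte.errors.RuntimeUserError`, a plain dict stands in for the Optuna study, and the toy objective rejects `max_depth` values above 15.

```python
import asyncio
from collections import Counter


class TrialFailed(Exception):
    """Hypothetical stand-in for flyte.errors.RuntimeUserError."""


async def objective(params: dict[str, int]) -> float:
    await asyncio.sleep(0)  # yield to the event loop, as a remote task call would
    if params["max_depth"] > 15:
        raise TrialFailed(f"max_depth={params['max_depth']} rejected")
    return 1.0 / params["max_depth"]  # toy score


async def spawn(number: int, semaphore: asyncio.Semaphore, states: dict[int, str]) -> None:
    async with semaphore:
        states[number] = "running"
        params = {"max_depth": 2 + 3 * number}  # toy stand-in for trial.suggest_int
        try:
            await objective(params)  # a real optimizer would tell() this score to the study
            states[number] = "complete"
        except TrialFailed as e:
            print(f"Trial {number} failed: {e}")
            states[number] = "fail"  # record the failure; the other trials keep going


async def main(n_trials: int = 8, concurrency: int = 3) -> Counter:
    states: dict[int, str] = {}
    semaphore = asyncio.Semaphore(concurrency)
    await asyncio.gather(*(spawn(i, semaphore, states) for i in range(n_trials)))
    return Counter(states.values())


counts = asyncio.run(main())
print(counts)
```

Only the expected exception type is caught; any other error still propagates out of `gather`, which is usually what you want for genuine bugs.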

```
    async def __call__(self):
        # create semaphore to manage concurrency
        semaphore = asyncio.Semaphore(self.concurrency)

        # create list of async trials
        trials = [self.spawn(semaphore) for _ in range(self.n_trials)]

        logger: Optional[asyncio.Task] = None
        if self.log_delay:
            logger = asyncio.create_task(self.log())

        # await all trials to complete
        await asyncio.gather(*trials)

        if self.log_delay and logger:
            logger.cancel()
            try:
                await logger
            except asyncio.CancelledError:
                pass
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py*

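The `__call__` method defines the overall async optimization routine. It creates the semaphore, spawns `n_trials` coroutines, and optionally starts the background logging task. All trials are awaited with `asyncio.gather`; once they finish, the logger is cancelled and its `CancelledError` is suppressed so the method returns cleanly.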
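The logger handling in `__call__` follows a common asyncio pattern: start a never-ending background task, await the real work, then cancel the background task and absorb the resulting `CancelledError`. Here is a stripped-down, stdlib-only sketch of that pattern; the function names are hypothetical, and progress lines are collected in a list instead of printed so they can be inspected.

```python
import asyncio


async def log_progress(done: list[int], total: int, interval: float, lines: list[str]) -> None:
    """Record a progress line every `interval` seconds until cancelled."""
    while True:
        await asyncio.sleep(interval)
        lines.append(f"complete: {len(done)}    queued: {total - len(done)}")


async def trial(number: int, done: list[int]) -> None:
    await asyncio.sleep(0.05)  # stand-in for awaiting a real objective task
    done.append(number)


async def main(total: int = 4) -> tuple[list[str], bool]:
    done: list[int] = []
    lines: list[str] = []
    logger = asyncio.create_task(log_progress(done, total, 0.01, lines))
    await asyncio.gather(*(trial(i, done) for i in range(total)))
    logger.cancel()  # the logger loops forever, so it must be stopped explicitly
    try:
        await logger  # wait for the cancellation to take effect
    except asyncio.CancelledError:
        pass  # expected: the task was cancelled on purpose
    return lines, logger.cancelled()


lines, cancelled = asyncio.run(main())
print(len(lines), "progress lines; logger cancelled:", cancelled)
```

Awaiting the cancelled task, rather than just calling `cancel()`, ensures it has actually stopped before the function returns, which avoids "Task was destroyed but it is pending" warnings at shutdown.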

```
    def __len__(self) -> int:
        """Return the number of trials in history."""
        return len(self.study.trials)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py*

This method returns the number of trials already associated with the study; `log` uses it to compute how many trials are still queued.

## Define the objective function

The `objective` task evaluates one set of hyperparameters. Because it is a task in the `driver` environment, each call is cached, tracked, and recoverable across executions; because it is `async`, the optimizer can await many calls concurrently.

```
@driver.task
async def objective(params: dict[str, Union[int, float]]) -> float:
    data = load_iris()
    X, y = shuffle(data.data, data.target, random_state=42)

    clf = RandomForestClassifier(
        n_estimators=params["n_estimators"],
        max_depth=params["max_depth"],
        min_samples_split=params["min_samples_split"],
        random_state=42,
        n_jobs=-1,
    )

    # Use cross-validation to evaluate performance
    score = cross_val_score(clf, X, y, cv=3, scoring="accuracy").mean()

    return score.item()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py*

We use the Iris dataset as a toy classification problem. The `params` dictionary contains the trial's hyperparameters, which we unpack into a `RandomForestClassifier`. We shuffle the dataset with a fixed seed (`random_state=42`) so every trial sees the same sample order, then compute the mean 3-fold cross-validation accuracy.
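For readers new to k-fold cross-validation, the scoring step reduces to the plain-Python sketch below: split the samples into k folds, train on the other k-1 folds, score accuracy on the held-out fold, and average. The helpers and the majority-class "model" are hypothetical stand-ins for `cross_val_score` and `RandomForestClassifier`. One simplification to note: scikit-learn's `cross_val_score` uses stratified folds when given a classifier and an integer `cv`, while this sketch uses contiguous, unstratified folds. That is exactly the setting where data ordering matters, since Iris is stored sorted by class.

```python
from statistics import mean


def k_fold_indices(n: int, k: int) -> list[list[int]]:
    """Split range(n) into k contiguous folds whose sizes differ by at most one."""
    sizes = [n // k + (1 if i < n % k else 0) for i in range(k)]
    folds, start = [], 0
    for size in sizes:
        folds.append(list(range(start, start + size)))
        start += size
    return folds


def cross_val_accuracy(fit_predict, X: list, y: list, k: int = 3) -> float:
    """Mean held-out accuracy over k folds.

    fit_predict(X_train, y_train, X_test) must return predicted labels for X_test.
    """
    scores = []
    for test_idx in k_fold_indices(len(X), k):
        held_out = set(test_idx)
        train_idx = [i for i in range(len(X)) if i not in held_out]
        preds = fit_predict(
            [X[i] for i in train_idx], [y[i] for i in train_idx], [X[i] for i in test_idx]
        )
        correct = sum(p == y[i] for p, i in zip(preds, test_idx))
        scores.append(correct / len(test_idx))
    return mean(scores)


def majority_class(X_train: list, y_train: list, X_test: list) -> list:
    """Toy model: predict the most common training label for every test sample."""
    label = max(set(y_train), key=y_train.count)
    return [label] * len(X_test)


X = list(range(6))
y = [0, 0, 0, 1, 0, 0]
print(cross_val_accuracy(majority_class, X, y, k=3))  # fold scores 1.0, 0.5, 1.0
```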

## Define the main optimization loop

The optimize task is the main driver of our optimization experiment. It creates the `Optimizer` instance and invokes it.

```
@driver.task
async def optimize(
    n_trials: int = 20,
    concurrency: int = 5,
    delay: float = 0.05,
    log_delay: float = 0.1,
) -> dict[str, Union[int, float]]:
    optimizer = Optimizer(
        objective=objective,
        n_trials=n_trials,
        concurrency=concurrency,
        delay=delay,
        log_delay=log_delay,
        study=optuna.create_study(
            direction="maximize", sampler=optuna.samplers.TPESampler(seed=42)
        ),
    )

    await optimizer()

    best = optimizer.study.best_trial

    print("✅ Best Trial")
    print("  Number :", best.number)
    print("  Params :", best.params)
    print("  Score  :", best.value)

    return best.params
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py*

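We configure Optuna's `TPESampler` with a fixed `seed` so its suggestions are reproducible, although concurrent trials can still finish in a different order from run to run. After all trials finish, we extract the best-performing trial and print its number, parameters, and score. Returning the best parameters lets downstream tasks or clients train a model with the tuned configuration.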

## Run the experiment

Finally, we include an executable entry point to run this optimization using `flyte.run`.

```
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(optimize, 100, 10)
    print(run.url)
    run.wait()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py*

We load the Flyte configuration from `config.yaml`, launch the `optimize` task with 100 trials and a concurrency of 10, print a link to the execution in the Flyte UI, and then block until the run finishes with `run.wait()`.

![HPO execution](https://raw.githubusercontent.com/unionai/unionai-docs-static/main/images/tutorials/hpo/execution.png)
_Each objective run is cached, recorded, and recoverable. With concurrency set to 10, only 10 trials execute in parallel at any given time._

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/data-processing ===

# Data processing

Tutorials for large-scale data processing and batching strategies.

### **Data processing > Batching strategies for efficient scaling**

Process millions of items efficiently with resilient, scalable batching patterns built on Flyte v2.

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/data-processing/micro-batching ===


# Batching strategies for efficient scaling

> [!NOTE]
> [View source on GitHub](https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/batching_patterns/batch_processing.ipynb) | [Run in Google Colab](https://colab.research.google.com/github/unionai/unionai-examples/blob/main/v2/tutorials/batching_patterns/batch_processing.ipynb)

This notebook demonstrates a production-ready pattern for processing millions of items efficiently with Flyte v2, using reusable containers and `@flyte.trace` checkpoints. You'll learn how to build resilient, scalable workflows that handle failures gracefully and make efficient use of compute resources.

## Use case

**The Challenge:** Processing massive datasets (100K to 1M+ items) that require external API calls or long-running operations.

**Real-World Examples:**
- Web scraping large lists of URLs
- Batch inference on millions of data points
- Processing documents through external APIs
- ETL pipelines with rate-limited services
- Data validation against third-party services

**The Problem:** When you have so many inputs that you must:
1. Split them into batches 
2. Submit each batch to an external service and wait for completion
3. Handle failures without losing progress
4. Optimize resource usage across thousands of operations

**Why This Matters:** Without proper batching and checkpointing, a single failure in a million-item workflow could force you to restart from scratch, wasting compute resources and time.

## Goals

**Our Goals:**
1. **Resilience:** Mitigate the impact of batches that take longer or fail
2. **Determinism:** Make operations with external API dependencies predictable and resumable
3. **Efficiency:** Optimize resource consumption through container reuse and parallel processing
4. **Cost Savings:** Minimize wasted compute by checkpointing progress

## Solution architecture

This example demonstrates a production-ready micro-batching pattern that combines two Union features:

### 1. Failure transparency with `@flyte.trace`
The `@flyte.trace` decorator creates automatic checkpoints:
- **What it does:** Records the inputs and outputs of decorated functions
- **Why it matters:** If a task fails and is retried, traced calls that already completed are replayed from their recorded outputs instead of running again
- **Result:** No re-execution of completed work
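To build intuition for checkpoint-and-resume, here is a toy, in-memory model of the idea. It is *not* how `@flyte.trace` is implemented: `checkpointed`, `store`, and the simulated crash are all hypothetical, and Flyte records traced outputs durably rather than in a Python dict. A step's output is recorded only after it succeeds, so a retry replays completed steps and re-runs only the one that failed.

```python
import asyncio
import functools
from typing import Any, Awaitable, Callable

# Hypothetical checkpoint store: (function name, args) -> recorded output.
# A dict is only for illustration; it does not survive a process crash.
store: dict[tuple, Any] = {}
calls: list[str] = []  # log of real (non-replayed) executions


def checkpointed(fn: Callable[..., Awaitable[Any]]) -> Callable[..., Awaitable[Any]]:
    @functools.wraps(fn)
    async def wrapper(*args: Any) -> Any:
        key = (fn.__name__, args)
        if key in store:          # completed before: replay the recorded output
            return store[key]
        result = await fn(*args)  # run for real
        store[key] = result       # record only after success
        return result
    return wrapper


@checkpointed
async def submit(items: tuple[int, ...]) -> dict[int, str]:
    calls.append("submit")
    return {i: f"job_{i}" for i in items}


attempts = {"wait": 0}


@checkpointed
async def wait(job_ids: tuple[str, ...]) -> list[str]:
    calls.append("wait")
    attempts["wait"] += 1
    if attempts["wait"] == 1:
        raise RuntimeError("simulated crash during wait")
    return [f"result_for_{j}" for j in job_ids]


async def batch(items: tuple[int, ...]) -> list[str]:
    jobs = await submit(items)
    return await wait(tuple(jobs.values()))


async def run_with_retry(items: tuple[int, ...], retries: int = 1) -> list[str]:
    for attempt in range(retries + 1):
        try:
            return await batch(items)
        except RuntimeError:
            if attempt == retries:
                raise
    raise AssertionError("unreachable")


if __name__ == "__main__":
    print(asyncio.run(run_with_retry((1, 2, 3))))
    print(calls)  # submit ran once; wait ran twice (failed, then succeeded)
```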

### 2. Reusable containers for efficiency
Instead of creating a new container for each task:
- **Container pools:** Pre-warmed replicas ready to handle work
- **Concurrent processing:** Each replica handles multiple items simultaneously
- **Automatic scaling:** Replicas scale between min/max based on workload
- **Resource optimization:** Dramatically reduced startup overhead

### Key benefits:
- **Automatic checkpointing** at batch and operation boundaries  
- **Resume from last successful point** on any failure  
- **No wasted compute** - never re-execute completed work  
- **Massive parallelism** - process thousands of batches concurrently  
- **Cost efficient** - container reuse minimizes cold-start overhead  

### Architecture flow:
```
1M items → Split into 1,000 batches (1K each)
         ↓
    Parallel processing across reusable container pool
         ↓
    Each batch: Submit → Poll → Checkpoint
         ↓
    Aggregate results from all batches
```

### Architecture diagram

![Micro-batching Architecture](./images/micro-batching.png)

**Diagram shows:**
- Input data split into batches
- Reusable container pool 
- Concurrent processing within each replica 
- Submit and wait phases with `@flyte.trace` checkpoints
- Parallel execution across all batches

## Implementation

### Step 0: Set up the runtime
Install the Flyte SDK and the reusable-containers plugin into the notebook environment.

```python
!uv pip install --no-cache --prerelease=allow --upgrade "flyte>=2.0.0b52" "unionai-reuse>=0.1.10"
```

### Step 1: Initialize Flyte configuration

Configure your connection to the Flyte cluster. This tells Flyte where to run your workflows and how to build container images.

**Configuration Options:**
- `endpoint`: Your Flyte cluster URL
- `org`: Your organization name
- `project`: Project to organize workflows
- `domain`: Environment (development, staging, production)
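- `image_builder`: Use "remote" to build images on the cluster (no local Docker required)
- `auth_type`: Authentication method; this example uses `"DeviceFlow"`, which logs in with a device code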

```python
# Initialize connection to your Flyte cluster
# Replace these values with your own cluster details

import flyte
flyte.init(
    endpoint="https://<MY_TENANT_HOST>",  # Your Union cluster URL
    org="demo",                           # Your organization
    project="flytesnacks",                # Your project name
    domain="development",                 # Environment: development/staging/production
    image_builder="remote",               # Build images on cluster (no local Docker needed)
    auth_type="DeviceFlow",               # Log in with a device code
)
```

```python
# Import required libraries
import asyncio                          # For concurrent async operations
from datetime import timedelta          # For time-based configuration
from pathlib import Path                # For file path handling
from typing import Dict, List           # For type hints

import flyte                            # Main Flyte SDK
from flyte.remote import Run            # For interacting with remote executions
```

```python
# ============================================
# CONFIGURATION: Adjust these for your use case
# ============================================

# Total number of items to process
# In production, this could be the size of your dataset
NUMBER_OF_INPUTS = 1_000_000  # 1 million items

# Size of each batch
# Considerations for choosing batch size:
# - Larger batches: Fewer tasks, more memory per task
# - Smaller batches: More granular checkpointing, better parallelism
# - Recommendation: Start with 1000-10000 depending on item complexity
BATCH_SIZE = 1000

# Example calculations:
# 1M items ÷ 1K batch = 1,000 batch tasks
# Each batch submits its 1K items concurrently within its container
```
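The arithmetic above assumes `NUMBER_OF_INPUTS` divides evenly by `BATCH_SIZE`. A small helper shows how a remainder becomes a final, shorter batch. The name `batch_ranges` is hypothetical, but it uses the same range logic as the orchestrator task later in this notebook.

```python
import math


def batch_ranges(total_items: int, batch_size: int) -> list[tuple[int, int]]:
    """Split [0, total_items) into half-open (start, end) ranges of at most batch_size."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    return [
        (start, min(start + batch_size, total_items))
        for start in range(0, total_items, batch_size)
    ]


ranges = batch_ranges(1_000_000, 1_000)
print(len(ranges), ranges[0], ranges[-1])      # 1000 (0, 1000) (999000, 1000000)

uneven = batch_ranges(2_500, 1_000)
print(uneven)                                   # [(0, 1000), (1000, 2000), (2000, 2500)]
print(math.ceil(2_500 / 1_000) == len(uneven))  # True: batch count is ceil(total / size)
```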

### Step 2: Define container image

Create a container image specification with all required dependencies.

**Key Dependencies:**
- `flyte>=2.0.0b52`: Flyte v2 SDK for workflow orchestration
- `unionai-reuse>=0.1.10`: Required for Reusable Containers feature

**Note:** You can add any additional packages your tasks need (e.g., `httpx` for API calls, `beautifulsoup4` for web scraping, etc.)

```python
# Define the container image that will run our tasks
# This image will be built once and shared across all task executions
image = (
    flyte.Image.from_debian_base()  # Start with a lightweight Debian base
    .with_pip_packages(
        "flyte>=2.0.0b52",          # Flyte v2 SDK
        "unionai-reuse>=0.1.10"      # Required for reusable containers
        # Add your own dependencies here
    )
)
```

### Step 3: Define task environments

Task environments encapsulate the runtime configuration for tasks. We'll create one with **Reusable Containers** for efficient batch processing.

#### What are reusable containers?

Instead of creating a new Kubernetes Pod for every task execution, Reusable Containers maintain a pool of pre-warmed replicas that can handle multiple tasks sequentially or concurrently.

**Benefits:**
- **Faster execution:** No container startup overhead (can save 10-60 seconds per task)
- **Better resource utilization:** Containers stay warm and handle multiple items
- **Cost savings:** Especially significant for tasks with expensive initialization
- **Concurrent processing:** Each replica can process multiple items simultaneously

```python
# Create a TaskEnvironment with Reusable Containers for batch processing
batch_env = flyte.TaskEnvironment(
    name="batch_processor",  # Name used for Kubernetes pods: batch_processor-<hash>

    # Resource allocation per replica (per pod)
    resources=flyte.Resources(
        memory="2Gi",  # Memory per replica
        cpu="1"        # CPU cores per replica
    ),

    # Reusable container configuration
    reusable=flyte.ReusePolicy(
        # Number of replica pods to maintain
        # (min, max) - scales between these values based on workload
        replicas=(3, 10),  # Start with 3, scale up to 10 as needed

        # Concurrency: how many task invocations (here, batches) each replica runs at once
        # Higher = more throughput per replica, but more memory usage
        concurrency=5,  # Each pod runs up to 5 batches concurrently

        # How long idle replicas stay alive before being torn down
        idle_ttl=timedelta(minutes=5),  # Keep warm for 5 minutes
    ),

    # Use the container image we defined earlier
    image=image,
)

# CAPACITY CALCULATION:
# With replicas=(3, 10) and concurrency=5:
# - Minimum concurrent processing: 3 replicas × 5 concurrency = 15 batches
# - Maximum concurrent processing: 10 replicas × 5 concurrency = 50 batches
#
# For 1,000 batches with these settings:
# - Best case: 50 batches processing simultaneously
# - Time to process all: ~20 rounds of execution (1,000 / 50)
# Note: the orchestrator task also caps in-flight batches with a semaphore,
# so the effective parallelism is the smaller of the two limits.
```
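The capacity numbers in the comments above can be reproduced with a little arithmetic. This sketch (the function name and the `max_in_flight` argument are illustrative) also accounts for an extra cap on in-flight batches, such as the semaphore in the orchestrator task later in this notebook. It assumes batches take roughly equal time and the pool has already scaled to its maximum.

```python
import math
from typing import Optional


def estimate_waves(
    n_batches: int,
    max_replicas: int,
    concurrency: int,
    max_in_flight: Optional[int] = None,
) -> tuple[int, int]:
    """Return (effective parallel batches, approximate number of execution waves).

    Assumes every batch takes roughly the same time and the pool has scaled
    to max_replicas; real runs also include scale-up and scheduling delays.
    """
    capacity = max_replicas * concurrency
    parallel = capacity if max_in_flight is None else min(capacity, max_in_flight)
    return parallel, math.ceil(n_batches / parallel)


print(estimate_waves(1_000, max_replicas=10, concurrency=5))                    # (50, 20)
print(estimate_waves(1_000, max_replicas=10, concurrency=5, max_in_flight=10))  # (10, 100)
```

The second call shows why an orchestrator-side limit matters: with a cap of 10, the pool's capacity of 50 is never reached.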

#### Understanding TaskEnvironment parameters

**name:** 
- Used as the prefix for Kubernetes pod names
- Example: `batch_processor-abc123`

**resources:** 
- Compute resources allocated to *each replica*
- Set based on your task's memory and CPU needs
- Tip: Monitor actual usage and adjust accordingly

**replicas (min, max):**
- Flyte autoscales between these values based on workload
- More replicas = more parallel processing capacity
- Consider your cluster's capacity and quota limits

**concurrency:**
- Number of task invocations (here, batches) that each replica's Python process runs concurrently as async coroutines
- This is *within* each replica, not across replicas
- Higher values increase throughput but require more memory
- Best for I/O-bound tasks (API calls, web scraping)
- For CPU-bound tasks, keep this lower (1-2)

**idle_ttl:**
- Time replicas stay alive without active work before shutdown
- Longer TTL = faster subsequent executions, higher resource costs
- Shorter TTL = lower costs, potential startup delays
- Recommendation: 5-15 minutes for typical workloads

**image:**
- The container image specification with all dependencies
- Built once and reused across all task executions

#### Creating the orchestrator environment

The orchestrator task coordinates all batch processing but doesn't need container reuse since it only runs once per workflow execution.

```python
# Create a separate environment for the orchestrator task
orchestrator_env = flyte.TaskEnvironment(
    name="orchestrator",

    # depends_on: the orchestrator calls tasks defined in batch_env, so batch_env
    # is deployed together with this environment. Flyte builds batch_env's image
    # first and, since both environments share the same `image`, reuses it here.
    depends_on=[batch_env],

    # Orchestrator needs more memory to track all batch executions
    # but doesn't need reusable containers (runs once per workflow)
    resources=flyte.Resources(
        memory="4Gi",  # More memory to manage many parallel batches
        cpu="1"        # Single CPU is sufficient for orchestration
    ),

    image=image,  # Same image, different resource allocation
)
```

#### Why two environments?

**Separation of Concerns:**
- **Batch Environment:** Does the heavy lifting (processing items)
  - Needs reusable containers for efficiency
  - Scales horizontally (many replicas)
  - I/O bound operations benefit from concurrency

- **Orchestrator Environment:** Coordinates the workflow
  - Runs once per workflow execution
  - Doesn't need container reuse
  - Needs enough memory to track all batches
  - Mostly waits on batch tasks, so a single CPU is sufficient

This separation optimizes both cost and performance.

### Step 4: Define external service interactions

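These helper functions simulate interactions with external services (APIs, web scraping, etc.). They are plain async functions rather than Flyte tasks, so they run inside whichever task calls them.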

```python
async def submit_to_service(request_id: int) -> str:
    """
    Submit a request to an external service and get a job ID.

    This simulates the "submit" phase of a batch job pattern where you:
    1. Send data to an external service
    2. Receive a job/task ID for tracking
    3. Use that ID to poll for completion later

    PRODUCTION IMPLEMENTATION:
    Replace this simulation with your actual service call:

    ```python
    async with httpx.AsyncClient() as client:
        response = await client.post(
            "https://your-service.com/api/submit",
            json={"request_id": request_id, "data": your_data},
            timeout=30.0
        )
        response.raise_for_status()
        return response.json()["job_id"]
    ```

    Args:
        request_id: Unique identifier for this request

    Returns:
        job_id: Identifier to track this job's progress
    """
    await asyncio.sleep(0.01)  # Simulate network latency
    job_id = f"job_{request_id}"
    return job_id

async def poll_job_status(job_id: str, request_id: int) -> int:
    """
    Poll an external service until a job completes and return results.

    This simulates the "wait" phase where you:
    1. Repeatedly check if a submitted job has completed
    2. Wait between checks to avoid overwhelming the service
    3. Return the final result when ready

    PRODUCTION IMPLEMENTATION:
    Replace this simulation with your actual polling logic:

    ```python
    async with httpx.AsyncClient() as client:
        max_attempts = 60  # 5 minutes with 5-second intervals

        for attempt in range(max_attempts):
            response = await client.get(
                f"https://your-service.com/api/status/{job_id}",
                timeout=10.0
            )
            response.raise_for_status()
            status = response.json()

            if status["state"] == "completed":
                return status["result"]
            elif status["state"] == "failed":
                raise Exception(f"Job {job_id} failed: {status['error']}")

            # Wait before next poll
            await asyncio.sleep(5)

        raise TimeoutError(f"Job {job_id} did not complete in time")
    ```

    Args:
        job_id: The job identifier from submit_to_service
        request_id: Original request ID for logging/tracking

    Returns:
        result: The processed result from the external service
    """
    await asyncio.sleep(0.05)  # Simulate polling + processing time
    return request_id * 2  # Dummy result

# IMPORTANT NOTES:
# 1. Both functions are async - they don't block while waiting
# 2. Add logging for debugging and monitoring
```
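The polling loop in the `poll_job_status` docstring uses a fixed interval and a maximum attempt count. A variant with an overall deadline and capped exponential backoff is sketched below. `poll_until_done` and its `check` callback are hypothetical names, and `fake_status` stands in for a real status endpoint.

```python
import asyncio
from typing import Awaitable, Callable, Optional, TypeVar

T = TypeVar("T")


async def poll_until_done(
    check: Callable[[], Awaitable[Optional[T]]],
    timeout: float,
    initial_interval: float = 0.01,
    max_interval: float = 1.0,
) -> T:
    """Call `check` until it returns a non-None value or `timeout` seconds pass.

    `check` returns None while the job is still running and the result once done;
    it may raise to signal a failed job. The wait doubles after each poll, up to max_interval.
    """

    async def loop() -> T:
        interval = initial_interval
        while True:
            result = await check()
            if result is not None:
                return result
            await asyncio.sleep(interval)
            interval = min(interval * 2, max_interval)

    # wait_for cancels the loop and raises asyncio.TimeoutError at the deadline
    return await asyncio.wait_for(loop(), timeout=timeout)


async def demo() -> int:
    polls = 0

    async def fake_status() -> Optional[int]:
        # Hypothetical job that reports completion on the third poll.
        nonlocal polls
        polls += 1
        return 42 if polls >= 3 else None

    return await poll_until_done(fake_status, timeout=5.0)


print(asyncio.run(demo()))  # 42
```

A single deadline bounds the total wait regardless of how the interval grows, which keeps one stuck job from holding a batch open indefinitely.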

### Step 5: Implement the batch processing task

This is the heart of the pattern. The `process_batch` task processes a batch of items with automatic checkpointing using `@flyte.trace`.

#### Key concepts:

**Two-Phase Processing:**
1. **Submit Phase:** Send all items to external service concurrently
2. **Wait Phase:** Poll for completion of all submitted jobs

**Why @flyte.trace?**
- Creates checkpoints at phase boundaries
- If the task fails during wait phase, it resumes from there (doesn't re-submit)
- Enables forward recovery without re-execution

**Concurrency Pattern:**
- Uses `asyncio.gather()` to process all items in a batch simultaneously
- `return_exceptions=True` prevents one failure from stopping the batch
- Each phase completes fully before moving to the next
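The `return_exceptions=True` behavior is easy to check locally. In this sketch, `flaky_submit` is a hypothetical stand-in for `submit_to_service` that fails for one item. Instead of aborting the whole `gather`, the exception comes back in that item's slot, in input order, and is mapped to `None` and then to the sentinel `-1`. `asyncio.sleep(0, result=...)` serves as a placeholder coroutine for slots that should not be polled.

```python
import asyncio
from typing import Optional


async def flaky_submit(request_id: int) -> str:
    # Hypothetical submit call that fails for one specific item.
    await asyncio.sleep(0)
    if request_id == 2:
        raise ConnectionError(f"service rejected {request_id}")
    return f"job_{request_id}"


async def submit_all(items: list[int]) -> dict[int, Optional[str]]:
    outcomes = await asyncio.gather(
        *(flaky_submit(x) for x in items),
        return_exceptions=True,  # exceptions are returned in place, not raised
    )
    # gather preserves input order, so zip lines outcomes up with items
    return {
        request_id: None if isinstance(outcome, Exception) else outcome
        for request_id, outcome in zip(items, outcomes)
    }


async def wait_all(job_mapping: dict[int, Optional[str]]) -> list[int]:
    results = await asyncio.gather(
        *(
            asyncio.sleep(0, result=request_id * 2)  # stand-in for polling a real job
            if job_id is not None
            else asyncio.sleep(0, result=-1)         # failed submission -> sentinel
            for request_id, job_id in job_mapping.items()
        ),
        return_exceptions=True,
    )
    return [-1 if isinstance(r, Exception) else r for r in results]


mapping = asyncio.run(submit_all([1, 2, 3]))
print(mapping)                         # {1: 'job_1', 2: None, 3: 'job_3'}
print(asyncio.run(wait_all(mapping)))  # [2, -1, 6]
```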

```python
@batch_env.task  # This task runs in the reusable container pool
async def process_batch(batch_start: int, batch_end: int) -> List[int]:
    """
    Process a single batch of items with checkpointed phases.

    This function demonstrates the core micro-batching pattern with:
    1. Two-phase processing (submit → wait)
    2. Automatic checkpointing via @flyte.trace
    3. Error handling without stopping the entire batch
    4. Concurrent processing within the batch

    Args:
        batch_start: Starting index for this batch (inclusive)
        batch_end: Ending index for this batch (exclusive)

    Returns:
        List of processed results (or -1 for failed items)

    Example:
        process_batch(0, 1000) processes items 0-999
        process_batch(1000, 2000) processes items 1000-1999
    """

    # ========================================
    # PHASE 1: SUBMIT ALL ITEMS TO SERVICE
    # ========================================
    @flyte.trace  # Creates a checkpoint after this phase completes
    async def submit_phase(items: List[int]) -> Dict[int, str]:
        """
        Submit all items concurrently and collect job IDs.

        This function:
        1. Launches submit_to_service() for ALL items simultaneously
        2. Waits for all submissions to complete with asyncio.gather()
        3. Handles errors gracefully (return_exceptions=True)
        4. Maps each request_id to its job_id (or None if failed)

        Why @flyte.trace here:
        - If this phase succeeds but wait_phase fails, we don't re-submit
        - Checkpointed data includes all job_ids for the wait phase
        - Forward recovery from exact failure point

        """

        job_ids = await asyncio.gather(
            *(submit_to_service(request_id=x) for x in items),
            return_exceptions=True  # Don't stop on individual failures
        )

        # Map request IDs to job IDs (or None for failures)
        job_mapping = {}
        for request_id, job_id in zip(items, job_ids):
            if isinstance(job_id, Exception):
                print(f"[ERROR] Submit failed for {request_id}: {job_id}")
                job_mapping[request_id] = None  # Mark as failed
            else:
                job_mapping[request_id] = job_id

        return job_mapping

    # ========================================
    # PHASE 2: WAIT FOR ALL JOBS TO COMPLETE
    # ========================================
    @flyte.trace  # Creates another checkpoint after this phase completes
    async def wait_phase(job_mapping: Dict[int, str]) -> List[int]:
        """
        Poll all submitted jobs until completion.

        This function:
        1. Takes the checkpointed job_mapping from submit_phase
        2. Polls all jobs concurrently
        3. Handles polling errors gracefully
        4. Returns final results

        WHY @flyte.trace HERE:
        - Once wait_phase completes, its results are recorded; a retry after
          that point replays them instead of polling again
        - If polling fails partway through, wait_phase re-runs with the cached
          job_mapping, so no jobs are re-submitted
        - Polling typically only reads job status, so re-running it is safe

        ERROR HANDLING:
        - Jobs that failed in submit_phase (None) are not polled; they resolve to -1
        - Polling failures are caught and marked as -1
        - The batch continues even if some items fail
        """
        # Poll ALL jobs concurrently
        results = await asyncio.gather(
            *(
                poll_job_status(job_id=job_id, request_id=request_id)
                if job_id is not None  # Only poll successfully submitted jobs
                else asyncio.sleep(0, result=-1)  # Failed submissions resolve to -1
                for request_id, job_id in job_mapping.items()
            ),
            return_exceptions=True  # Don't stop on individual failures
        )

        # Process results and handle errors
        processed_results = []
        for request_id, result in zip(job_mapping.keys(), results):
            if isinstance(result, Exception):
                print(f"[ERROR] Wait failed for {request_id}: {result}")
                processed_results.append(-1)  # Mark as failed
            else:
                processed_results.append(result)

        return processed_results

    # ========================================
    # EXECUTE BOTH PHASES SEQUENTIALLY
    # ========================================
    # Create the list of items for this batch
    items = list(range(batch_start, batch_end))

    # Phase 1: Submit all items and get job IDs (checkpointed)
    job_mapping = await submit_phase(items)

    # Phase 2: Wait for all jobs to complete (checkpointed)
    results = await wait_phase(job_mapping)

    # Log batch completion stats
    successful = len([r for r in results if r != -1])
    print(f"Batch {batch_start}-{batch_end}: {successful}/{len(results)} successful")

    return results

# ========================================
# CHECKPOINT & RECOVERY BEHAVIOR
# ========================================
#
# Scenario 1: Task fails during submit_phase
# → No checkpoint exists yet, so a retry re-runs submit_phase from the start
# → Items submitted before the failure may be submitted again, so make the
#   submit call idempotent where possible
#
# Scenario 2: Task fails after submit_phase completes
# → Resumes directly to wait_phase with cached job_mapping
# → No re-submissions!
#
# Scenario 3: Task fails during wait_phase
# → submit_phase is replayed from its checkpoint (no re-submission)
# → wait_phase re-runs and polls the jobs again, which is safe because
#   polling only reads job status

```

#### Understanding @flyte.trace

**Why use it for both phases:**
- Submit phase checkpoint = "These jobs were submitted successfully"
- Wait phase checkpoint = "These results were retrieved successfully"
- Without it: A failure in submit or wait phase would re-submit or re-poll everything

**Best Practices:**
- Use `@flyte.trace` for non-deterministic operations (API calls, random operations)
- Don't use it for pure, deterministic functions (unnecessary overhead)
- Ensure traced functions are idempotent when possible
- Keep traced function signatures simple (serializable inputs/outputs)
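Idempotency matters most for the submit phase: if a task fails before `submit_phase` has been recorded, a retry submits the same items again. One common remedy, assuming your service accepts a client-supplied idempotency key, is to derive that key from the request ID so a repeated submission returns the existing job. `FakeService` and `submit_batch` below are hypothetical, in-memory stand-ins.

```python
import asyncio


class FakeService:
    """Hypothetical service that deduplicates submissions by idempotency key."""

    def __init__(self) -> None:
        self.jobs: dict[str, str] = {}  # idempotency key -> job ID
        self.created = 0                 # number of jobs actually created

    async def submit(self, idempotency_key: str) -> str:
        await asyncio.sleep(0)  # simulate a network round trip
        if idempotency_key not in self.jobs:
            self.created += 1
            self.jobs[idempotency_key] = f"job_{self.created}"
        return self.jobs[idempotency_key]


async def submit_batch(service: FakeService, items: list[int]) -> dict[int, str]:
    # Deterministic key per item: retries reuse the same key, so no duplicate jobs.
    job_ids = await asyncio.gather(*(service.submit(f"item-{i}") for i in items))
    return dict(zip(items, job_ids))


async def main() -> None:
    service = FakeService()
    first = await submit_batch(service, [0, 1, 2])
    retry = await submit_batch(service, [0, 1, 2])  # e.g. after a crash before the checkpoint
    print(first == retry, service.created)  # True 3


asyncio.run(main())
```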

See the [Traces](/docs/v2/union//user-guide/task-programming/traces/) docs for more details on how `@flyte.trace` works.

### Step 6: Implement the orchestrator workflow

The orchestrator is the top-level task that:
1. Splits the total workload into batches
2. Launches all batches in parallel
3. Aggregates results from all batches
4. Reports overall statistics

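**This is where the magic happens:** All batches are launched at once, and the number running at the same time is bounded by both the orchestrator's semaphore (`max_concurrent_batches`) and the capacity of the reusable container pool.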

```python
@orchestrator_env.task  # Runs in the orchestrator environment (no reuse)
async def microbatch_workflow(
    total_items: int = NUMBER_OF_INPUTS,
    batch_size: int = BATCH_SIZE,
) -> List[int]:
    """
    Main task orchestrating the entire micro-batching process.

    This task:
    1. Splits the items into batch ranges
    2. Launches all batch tasks in parallel
    3. Aggregates results from completed batches
    4. Provides comprehensive execution statistics

    Args:
        total_items: Total number of items to process (default: 1M)
        batch_size: Number of items per batch (default: 1K)

    Returns:
        Aggregated results from all batches (list of processed values)

    Execution Flow:
        1M items → 1,000 batches → Parallel execution → Aggregated results

    Resource Usage:
        - This task: 4Gi memory, 1 CPU (orchestration only)
        - Each batch task: 2Gi memory, 1 CPU (from batch_env)
        - Reusable containers handle actual processing
    """

    # ========================================
    # STEP 1: CALCULATE BATCH DISTRIBUTION
    # ========================================
    # Split total items into batch ranges: [(0, 1000), (1000, 2000), ...]
    batches = [
        (start, min(start + batch_size, total_items))
        for start in range(0, total_items, batch_size)
    ]

    print(f"Processing {total_items:,} items in {len(batches):,} batches of size {batch_size:,}")
    print(f"Expected parallelism: {batch_env.reusable.replicas[0]}-{batch_env.reusable.replicas[1]} replicas")
    print(f"Concurrency per replica: {batch_env.reusable.concurrency}")
    print(f"Pool capacity (max batches the pool can run at once): {batch_env.reusable.replicas[1] * batch_env.reusable.concurrency}")

    # ========================================
    # STEP 2: LAUNCH ALL BATCHES IN PARALLEL
    # ========================================
    # This is the key to massive parallelism:
    # - Creates one coroutine per batch up front
    # - A semaphore caps how many batches are in flight at once
    # - Reusable containers handle the workload efficiently
    # - return_exceptions=True prevents one batch failure from stopping all

    print(f"\n Launching {len(batches):,} parallel batch tasks...")

    # Rate limiter to control API throughput
    max_concurrent_batches = 10  # Adjust based on API rate limits
    semaphore = asyncio.Semaphore(max_concurrent_batches)

    async def rate_limited_batch(start: int, end: int):
        """Wrapper to enforce rate limiting on batch processing."""
        async with semaphore:
            return await process_batch(batch_start=start, batch_end=end)

    batch_results = await asyncio.gather(
        *(rate_limited_batch(start, end) for start, end in batches),
        return_exceptions=True  # Isolated failure handling per batch
    )
    # ========================================
    # STEP 3: AGGREGATE RESULTS & STATISTICS
    # ========================================
    all_results = []
    failed_batches = 0
    failed_items = 0

    for i, batch_result in enumerate(batch_results):
        if isinstance(batch_result, Exception):
            # Entire batch failed (task-level failure)
            print(f"[ERROR] Batch {i} failed completely: {batch_result}")
            failed_batches += 1
        else:
            # Batch completed, but individual items may have failed
            all_results.extend(batch_result)
            failed_items += len([r for r in batch_result if r == -1])

    # Calculate final statistics
    success_count = len([r for r in all_results if r != -1])
    total_processed = len(all_results)

    # ========================================
    # STEP 4: REPORT EXECUTION SUMMARY
    # ========================================
    print(f"\n{'=' * 60}")
    print(f" Execution summary")
    print(f"{'=' * 60}")
    print(f"Total items requested:    {total_items:,}")
    print(f"Total batches:            {len(batches):,}")
    print(f"Batch size:               {batch_size:,}")
    print(f"")
    print(f" Successful items:       {success_count:,}")
    print(f" Failed items:           {failed_items:,}")
    print(f" Failed batches:         {failed_batches}")
    print(f"")
    print(f" Success rate:           {success_count / total_items * 100:.2f}%")
    print(f" Items processed:        {total_processed:,} / {total_items:,}")
    print(f"{'=' * 60}\n")

    return all_results

# ========================================
# EXECUTION BEHAVIOR & OPTIMIZATION
# ========================================
#
# Parallel Execution Pattern:
# ┌─────────────────────────────────────────────────┐
# │ Orchestrator Task (1 pod, 4Gi, 1 CPU)         │
# │                                                 │
# │ Launches 1,000 process_batch() invocations     │
# └─────────────────┬───────────────────────────────┘
#                   │
#           ┌───────┴────────┐
#           ▼                ▼
#   ┌──────────────┐  ┌──────────────┐
#   │ Replica 1    │  │ Replica 2    │  ... up to 10 replicas
#   │ 2Gi, 1 CPU   │  │ 2Gi, 1 CPU   │
#   │              │  │              │
#   │ Concurrency: │  │ Concurrency: │
#   │ 5 batches    │  │ 5 batches    │
#   └──────────────┘  └──────────────┘
#
# Pool capacity: 10 replicas × 5 concurrency = 50 batches at once.
# The semaphore (max_concurrent_batches = 10) is the tighter limit here, so
# 1,000 batches take ≈ 1,000 / 10 = 100 waves; raising it to 50 gives ≈ 20 waves.
#
# Optimization Tips:
# 1. Increase replicas for more parallelism (if cluster allows)
# 2. Adjust concurrency based on task I/O vs CPU profile
# 3. Tune batch_size to balance granularity vs overhead
# 4. Monitor actual execution to find bottlenecks
# 5. Use Flyte UI to visualize execution patterns
```

### Step 7: Execute the workflow

Now let's run the entire workflow remotely on your Union cluster.

**Execution Options:**
- **Remote execution** (shown below): Runs on the Union cluster
- **Local execution**: Use `flyte.with_runcontext(mode="local").run()` for testing

**What happens during execution:**
1. Flyte builds the container image (if needed)
2. Creates the orchestrator pod
3. Orchestrator calculates batches and launches batch tasks
4. Reusable container pool starts spinning up (min: 3 replicas in this example)
5. Batches are distributed across available replicas
6. Pool scales up to max replicas (10 in this example) as needed
7. Results are aggregated and returned

```python
if __name__ == "__main__":
    print("=" * 60)
    print(" STARTING MICRO-BATCHING WORKFLOW")
    print("=" * 60)
    print(f"Total items to process: {NUMBER_OF_INPUTS:,}")
    print(f"Batch size: {BATCH_SIZE:,}")
    print(f"Expected batches: {NUMBER_OF_INPUTS // BATCH_SIZE:,}")
    print("=" * 60)
    print()

    # Launch the workflow remotely (runs on Flyte cluster)
    # Top-level `await` works in Jupyter because flyte.run.aio() is async;
    # in a plain Python script, call flyte.run(microbatch_workflow) instead.
    r = await flyte.run.aio(microbatch_workflow)

    # Print execution details
    print(f"\n{'=' * 60}")
    print(f" EXECUTION STARTED")
    print(f"{'=' * 60}")
    # print(f"Run name: {r.name}")  # Internal run identifier
    print(f"🔗 Execution URL: {r.url}")
    print(f"\n💡 Visit the URL above to:")
    print(f"   • View the execution graph and task timeline")
    print(f"   • Monitor progress in real-time")
    print(f"   • See trace checkpoints in action")
    print(f"   • Inspect logs for each batch")
    print(f"   • Analyze resource utilization")
    print(f"{'=' * 60}\n")

# ========================================
# MONITORING AND DEBUGGING TIPS
# ========================================
#
# 1. View Execution in UI:
#    - Click the execution URL printed above
#    - See visual graph of all batch tasks
#    - Monitor which batches are running/completed/failed
#
# 2. Check Logs:
#    - Click on individual batch tasks in the graph
#    - View stdout/stderr for debugging
#    - See checkpoint/recovery messages
#
# 3. Resource Utilization:
#    - Navigate to Resources tab in UI
#    - Monitor CPU/memory usage per task
#    - Identify bottlenecks or over-provisioning
#
# 4. Trace Visualization:
#    - Expand batch tasks to see trace checkpoints
#    - Verify submit_phase and wait_phase separately
#    - Understand recovery points on failures
#
# 5. Performance Analysis:
#    - Check task durations in timeline view
#    - Identify slow batches or stragglers
#    - Optimize batch_size or concurrency based on results
```
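For tip 5, once you have per-batch durations from the timeline view, a simple median-based rule can flag stragglers. The sketch assumes you have collected those durations yourself; the values shown are hypothetical. It does not call any Union or Flyte API, and the 2x threshold is only an arbitrary starting point.

```python
from statistics import median


def find_stragglers(durations: dict[str, float], factor: float = 2.0) -> list[str]:
    """Return batch names whose duration exceeds `factor` times the median duration."""
    if not durations:
        return []
    typical = median(durations.values())
    return sorted(name for name, secs in durations.items() if secs > factor * typical)


# Hypothetical per-batch durations (seconds), e.g. read off the timeline view.
durations = {"batch-0": 41.0, "batch-1": 39.5, "batch-2": 118.2, "batch-3": 40.7}
print(find_stragglers(durations))  # ['batch-2']
```

If the same batches are consistently slow, the cause is often uneven item sizes. Shuffling items or lowering `BATCH_SIZE` can spread that work more evenly.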

On execution, this is what this example looks like at the Kubernetes level:

![](./images/reusable-containers-k8s.png)

That is, 10 replicas (the maximum defined in the `TaskEnvironment`) plus the driver Pod that runs the parent task (`a0`). [Learn more about the parent task](/docs/v2/union//user-guide/considerations/#driver-pod-requirements).

## Batch size selection

**Finding the optimal batch size:**
- **Too small:** More overhead from task management, less efficient
- **Too large:** Longer recovery time on failures, higher memory usage

**Factors to consider:**
- Item processing time (shorter = larger batches, so the fixed per-task overhead is amortized over more items)
- Memory consumption per item (higher = smaller batches)
- Failure tolerance (critical = smaller batches for faster recovery)
- Total workload size (larger total = can use larger batches)

Read the [Optimization strategies](/docs/v2/union//user-guide/run-scaling/scale-your-workflows/#2-batch-workloads-to-reduce-overhead) page to understand the overheads associated with an execution and how to choose an appropriate batch size.
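A back-of-the-envelope model makes this trade-off concrete. It assumes that each batch task pays a fixed overhead plus a per-item cost, and that batches run in waves of at most `replicas` at a time. It ignores autoscaling, failures, and the smaller final batch. The `batching_estimate` function and all numbers below are hypothetical, so measure your own per-task overhead and per-item time before relying on the result.

```python
import math


def batching_estimate(n_items: int, batch_size: int, per_item_s: float,
                      per_task_overhead_s: float, replicas: int) -> dict[str, float]:
    """Rough estimate of batch count, overhead share, and wall-clock time.

    Simplifying assumptions: every batch is treated as full, pays a fixed
    overhead, and batches run in waves of `replicas` at a time.
    """
    n_batches = (n_items + batch_size - 1) // batch_size  # ceiling division
    batch_s = per_task_overhead_s + batch_size * per_item_s
    overhead_share = per_task_overhead_s / batch_s
    waves = math.ceil(n_batches / replicas)
    return {"batches": n_batches, "overhead_share": overhead_share,
            "wall_clock_s": waves * batch_s}


# Hypothetical workload: 100k fast items, 2 s overhead per task, 10 replicas.
for size in (100, 1_000, 10_000, 50_000):
    est = batching_estimate(100_000, size, per_item_s=0.01,
                            per_task_overhead_s=2.0, replicas=10)
    print(f"batch_size={size:>6}: {est['batches']:>5} batches, "
          f"overhead {est['overhead_share']:.0%}, ~{est['wall_clock_s']:.0f}s")
```

With fast items, tiny batches spend most of their time on overhead. A very large batch leaves replicas idle, because there are fewer batches than replicas. The best wall-clock time in this model falls in between these two extremes.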

## Summary

This notebook demonstrated a production-ready micro-batching pattern for Flyte v2 that combines:

1. **Reusable Containers** for efficiency
2. **@flyte.trace** for checkpointing and recovery
3. **Massive parallelism** via async/await
4. **Robust error handling** for resilience

**Key Takeaways:**
- Use `@flyte.trace` for non-deterministic operations
- Monitor resource usage and optimize incrementally
- Choose the right pattern for your specific use case

**Next Steps:**
- Adapt this pattern to your specific use case
- Replace mock functions with real API calls
- Test with your actual dataset
- Monitor and optimize based on production metrics

