Skip to content

Filter Operation

The Filter operator evaluates a natural-language (NL) condition on each item in a column.

Core Implementation

nirvana.ops.filter.FilterOperation(user_instruction: str = '', input_columns: list[str] = [], context: list[dict] | str | None = None, model: str | None = None, tool: Callable | BaseTool | None = None, strategy: Literal['plain', 'fewshot', 'self-refine'] = 'plain', limit: int | None = None, rate_limit: int = 16, assertions: list[Callable] | None = [])

Bases: BaseOperation

Filter operator: Uses an LLM to evaluate a natural language predicate on a column

Source code in nirvana/ops/filter.py
def __init__(
    self,
    user_instruction: str = "",
    input_columns: list[str] = [],
    context: list[dict] | str | None = None,
    model: str | None = None,
    tool: Callable | BaseTool | None = None,
    strategy: Literal["plain", "fewshot", "self-refine"] = "plain",
    limit: int | None = None,
    rate_limit: int = 16,
    assertions: list[Callable] | None = [],
):
    """Configure a filter operation.

    Args:
        user_instruction: Natural-language predicate forwarded to the base
            operation.
        input_columns: Columns the filter reads. ``None`` is treated as an
            empty list.
        context: Forwarded to the base operation unchanged.
        model: Forwarded to the base operation unchanged.
        tool: Either a ``BaseTool`` or a plain callable; a plain callable is
            wrapped with ``FunctionCallTool.from_function``.
        strategy: Execution strategy. The hyphenated spelling
            ``"self-refine"`` is accepted and stored as ``"self_refine"``.
        limit: Forwarded to the base operation unchanged.
        rate_limit: Forwarded to the base operation unchanged.
        assertions: Forwarded to the base operation as a fresh list
            (``None`` is forwarded as ``None``).
    """
    # Compare against None instead of relying on truthiness, so a callable
    # object that happens to be falsy is still wrapped.
    if tool is not None and not isinstance(tool, BaseTool):
        tool = FunctionCallTool.from_function(func=tool)

    # The signature advertises "self-refine", but `strategy_options` and the
    # dispatch in `execute` use "self_refine"; normalise to that spelling.
    if strategy == "self-refine":
        strategy = "self_refine"

    # The list defaults in the signature are single shared objects; copy so
    # no instance ever aliases them (or a caller's list).
    if assertions is not None:
        assertions = list(assertions)

    super().__init__(
        op_name="filter",
        user_instruction=user_instruction,
        context=context,
        model=model,
        tool=tool,
        strategy=strategy,
        limit=limit,
        rate_limit=rate_limit,
        assertions=assertions,
    )
    self.prompter = FilterPrompter()
    self.input_columns = list(input_columns) if input_columns is not None else []

Attributes

strategy_options = ['plain', 'fewshot', 'self_refine'] class-attribute instance-attribute
prompter = FilterPrompter() instance-attribute
input_columns = input_columns instance-attribute
dependencies: list[str] property
generated_fields: list[str] property
op_kwargs: dict property

Functions

execute(input_data: pd.DataFrame, **kwargs) async
Source code in nirvana/ops/filter.py
async def execute(
    self,
    input_data: pd.DataFrame,
    **kwargs
):
    """Evaluate the filter predicate on every row of ``input_data``.

    Args:
        input_data: Data frame holding at least the columns named in
            ``self.input_columns``.
        **kwargs: Extra keyword arguments forwarded to the strategy-specific
            execution method.

    Returns:
        ``FilterOpOutputs`` whose ``output`` has one entry per input row and
        whose ``cost`` is the sum of the per-row costs that were awaited.
        For an empty ``input_data`` a default ``FilterOpOutputs()`` is
        returned.

    Raises:
        ValueError: If neither an instruction nor a user function is
            available, if no input column is configured, if the few-shot
            strategy is selected without ``context``, or if the strategy is
            unknown.
    """
    # An empty string (the constructor default) is as unusable as None.
    if not self.user_instruction and not self.has_udf():
        raise ValueError("Neither `user_instruction` nor `func` is given.")

    if input_data.empty:
        return FilterOpOutputs()

    # Fail with a clear message instead of an IndexError/KeyError below.
    if not self.input_columns:
        raise ValueError("`input_columns` must contain at least one column name.")

    processed_data = input_data[self.input_columns]
    dtypes = [
        "image" if isinstance(input_data[col].dtype, ImageDtype) else "text"
        for col in self.input_columns
    ]

    # Keyword arguments shared by every strategy.
    common_kwargs = dict(
        dtypes=dtypes,
        field_name=self.input_columns[0],
        model=self.model,
        **kwargs,
    )
    if self.strategy == "plain":
        execution_func = functools.partial(self._execute_by_plain_llm, **common_kwargs)
    elif self.strategy == "fewshot":
        # Raise explicitly: an `assert` is stripped when Python runs with -O.
        if self.context is None:
            raise ValueError("Few-shot examples must be provided in the context for in-context learning.")
        execution_func = functools.partial(self._execute_by_fewshot_llm, demos=self.context, **common_kwargs)
    elif self.strategy in ("self_refine", "self-refine"):
        # The type hint spells it "self-refine"; `strategy_options` spells it
        # "self_refine". Accept both.
        execution_func = functools.partial(self._execute_by_self_refine, **common_kwargs)
    else:
        raise ValueError(f"The optional strategies available for filter are {self.strategy_options}. Strategy {self.strategy} is not supported.")

    # Validate the limit before any task is scheduled.
    if self.limit is not None and self.limit <= 0:
        warnings.warn("The limit should be positive. To execute, the limit will be ignored.")
        self.limit = None

    # Create tasks for all data points. Each task is expected to resolve to
    # a (passed, cost) pair.
    tasks = []
    for _, data in processed_data.iterrows():
        if data.empty:
            tasks.append(asyncio.create_task(asyncio.sleep(0, result=(False, 0.0))))
        elif self.has_udf():
            tasks.append(asyncio.create_task(self._execute_by_func(data, self.user_instruction, self.tool, execution_func)))
        else:
            tasks.append(asyncio.create_task(execution_func(data, self.user_instruction)))

    token_cost = 0.0
    filter_outputs: list[bool] = []
    if self.limit is None:
        results = await asyncio.gather(*tasks)
        filter_outputs = [result[0] for result in results]
        token_cost = sum(result[1] for result in results)
    else:
        num_passed_records = 0
        num_awaited = 0
        try:
            # Consume results in row order, `limit` rows at a time, and stop
            # as soon as `limit` rows have passed the predicate.
            for start in range(0, len(tasks), self.limit):
                batch_tasks = tasks[start:start + self.limit]
                batch_results = await asyncio.gather(*batch_tasks)
                num_awaited = start + len(batch_tasks)
                token_cost += sum(cost for _, cost in batch_results)
                for passed, _ in batch_results:
                    filter_outputs.append(passed)
                    if passed:
                        num_passed_records += 1
                    if num_passed_records >= self.limit:
                        break
                if num_passed_records >= self.limit:
                    break
        finally:
            # Tasks that were scheduled but never consumed must not keep
            # running in the background: cancel them and drain their
            # outcomes so nothing is left pending or unretrieved.
            leftover_tasks = tasks[num_awaited:]
            for task in leftover_tasks:
                task.cancel()
            if leftover_tasks:
                await asyncio.gather(*leftover_tasks, return_exceptions=True)
        # Rows that were not evaluated (or came after the limit was hit)
        # are reported as not passing.
        filter_outputs.extend([False] * (len(processed_data) - len(filter_outputs)))

    return FilterOpOutputs(
        output=filter_outputs,
        cost=token_cost
    )

Output Class

nirvana.ops.filter.FilterOpOutputs(cost: float = 0.0, output: Iterable[bool] = None) dataclass

Bases: BaseOpOutputs

Attributes

output: Iterable[bool] = None class-attribute instance-attribute

Functions

__add__(other: FilterOpOutputs)
Source code in nirvana/ops/filter.py
def __add__(self, other: "FilterOpOutputs"):
    """Combine two filter results into a new ``FilterOpOutputs``.

    The outputs are concatenated (this operand first) and the costs are
    summed. An ``output`` of ``None`` -- the dataclass default -- is treated
    as an empty sequence, so adding a default-constructed result no longer
    raises ``TypeError``.
    """
    # Materialise as lists so `+` always means concatenation, whatever
    # iterable type each side holds.
    left = [] if self.output is None else list(self.output)
    right = [] if other.output is None else list(other.output)
    return FilterOpOutputs(
        output=left + right,
        cost=self.cost + other.cost
    )

Function Wrapper

nirvana.ops.filter

filter_wrapper(input_data: DataFrame, user_instruction: str = None, input_columns: list[str] = None, func: Callable = None, context: list[dict] | str | None = None, model: str | None = None, strategy: Literal['plain', 'fewshot', 'self-refine'] = 'plain', limit: int | None = None, rate_limit: int = 16, assertions: list[Callable] | None = [], **kwargs)

A function wrapper for filter operation

Parameters:

Name Type Description Default
input_data DataFrame

Input dataframe

required
user_instruction str

User instruction. Defaults to None.

None
input_columns list[str]

Input columns. Defaults to None.

None
func Callable

User function. Defaults to None.

None
context list[dict] | str

Context. Defaults to None.

None
model str

Model. Defaults to None.

None
strategy Literal['plain', 'fewshot', 'self-refine']

Strategy. Defaults to "plain".

'plain'
limit int

Maximum number of outputs to produce before stopping.

None
rate_limit int

Rate limit. Defaults to 16.

16
assertions list[Callable]

Assertions. Defaults to [].

[]
**kwargs

Additional keyword arguments for OpenAI Client.

{}
Source code in nirvana/ops/filter.py
def filter_wrapper(
    input_data: pd.DataFrame,
    user_instruction: str = None,
    input_columns: list[str] = None,
    func: Callable = None,
    context: list[dict] | str | None = None,
    model: str | None = None,
    strategy: Literal["plain", "fewshot", "self-refine"] = "plain",
    limit: int | None = None,
    rate_limit: int = 16,
    assertions: list[Callable] | None = [],
    **kwargs
):
    """
    A function wrapper for filter operation

    Builds a ``FilterOperation`` from the arguments and runs its ``execute``
    coroutine to completion with ``asyncio.run``.

    Args:
        input_data (pd.DataFrame): Input dataframe
        user_instruction (str, optional): User instruction. Defaults to None.
        input_columns (list[str], optional): Input columns. Defaults to None.
        func (Callable, optional): User function. Defaults to None.
        context (list[dict] | str, optional): Context. Defaults to None.
        model (str, optional): Model. Defaults to None.
        strategy (Literal["plain", "fewshot", "self-refine"], optional): Strategy. Defaults to "plain".
        limit (int, optional): Maximum number of outputs to produce before stopping. Defaults to None.
        rate_limit (int, optional): Rate limit. Defaults to 16.
        assertions (list[Callable], optional): Assertions. Defaults to [].
        **kwargs: Additional keyword arguments for OpenAI Client.

    Returns:
        The value returned by ``FilterOperation.execute``.
    """
    # Wrap the user function up front; no function means no tool.
    tool = FunctionCallTool.from_function(func=func) if func is not None else None

    filter_op = FilterOperation(
        user_instruction=user_instruction,
        input_columns=input_columns,
        context=context,
        model=model,
        tool=tool,
        strategy=strategy,
        limit=limit,
        rate_limit=rate_limit,
        # The `[]` default is one shared list object; hand over a copy so
        # the operation can never alias it (or the caller's list).
        assertions=list(assertions) if assertions is not None else None,
    )
    # NOTE: asyncio.run starts its own event loop, so this wrapper cannot be
    # called from code that is already running inside one.
    return asyncio.run(filter_op.execute(
        input_data=input_data,
        **kwargs
    ))