Skip to content

Commit 9e48c66

Browse files
authored
Merge pull request #15 from jlowin/one-step
Consolidate AI logic
2 parents f43d32f + f237a0d commit 9e48c66

File tree

1 file changed

+23
-15
lines changed

1 file changed

+23
-15
lines changed

src/ai_labeler/ai.py

+23-15
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ def validate_labels(result: list[str]):
3232

3333
class Reasoning(BaseModel):
3434
label_name: str
35-
reasoning: str
35+
# reasoning: str
36+
should_apply: bool
3637

3738
reasoning = cf.run(
3839
"""
@@ -49,8 +50,10 @@ class Reasoning(BaseModel):
4950
and the label's instructions into account. Some labels will have
5051
specific instructions about when to apply them, or whether to apply them
5152
at all. Be sure to reference all relevant context and instructions in
52-
your reasoning. You do not need to reason about labels that are
53-
obviously irrelevant.
53+
your reasoning.
54+
55+
You do not need to return reasoning about labels that are obviously
56+
irrelevant.
5457
""",
5558
instructions=instructions,
5659
result_type=list[Reasoning],
@@ -65,18 +68,23 @@ class Reasoning(BaseModel):
6568
model_kwargs=dict(tool_choice="required"), # prevent chatting
6669
)
6770

68-
decision = cf.run(
69-
"""
70-
Based on the reasoning for each label, return the list of labels that
71-
should be applied. If no labels apply, return an empty list.
72-
""",
73-
result_type=list[str],
74-
result_validator=validate_labels,
75-
context={"reasoning": reasoning, "available_labels": labels},
76-
agents=[labeler],
77-
completion_tools=["SUCCEED"], # the task can not be marked as failed
78-
model_kwargs=dict(tool_choice="required"), # prevent chatting
79-
)
71+
decision = [r.label_name for r in reasoning if r.should_apply]
72+
73+
# --- old two-step approach. Adding `should_apply` to the reasoning model
74+
# appears to match performance in a single step.
75+
#
76+
# decision = cf.run(
77+
# """
78+
# Based on the reasoning for each label, return the list of labels that
79+
# should be applied. If no labels apply, return an empty list.
80+
# """,
81+
# result_type=list[str],
82+
# result_validator=validate_labels,
83+
# context={"reasoning": reasoning, "available_labels": labels},
84+
# agents=[labeler],
85+
# completion_tools=["SUCCEED"], # the task can not be marked as failed
86+
# model_kwargs=dict(tool_choice="required"), # prevent chatting
87+
# )
8088

8189
print(f"Available labels: {dict(enumerate(labels))}")
8290
print(f"\n\nReasoning: {reasoning}")

0 commit comments

Comments
 (0)