fix(langchain): check chain.invoke() argument name [backport #8835 to 2.8] (#8987)

Backport #8835 to 2.8.

This PR fixes the langchain integration's patched chain method to check
for the correct input argument name.

In `LangChain<0.1`, we patch `langchain.Chain.__call__()`, which takes
`inputs: Union[Dict[str, Any], str]` as the argument to the chain
invocation. However, in `LangChain>=0.1`, we patch
`langchain.Chain.invoke()`, which takes `input: Dict[str, Any]` instead.
We use the same traced function to patch both methods, and this subtle
rename broke our argument parsing, which only looked for the keyword
name `inputs`, never `input`.
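
To make the breakage concrete, here is a minimal, self-contained sketch (not the actual ddtrace helper) of an argument lookup that, like the old code, only knows the keyword name `inputs`. The `simplified_get_argument_value` function and the `ArgumentError` class below are illustrative stand-ins for ddtrace internals:

```python
class ArgumentError(Exception):
    """Raised when the expected chain argument cannot be found."""


def simplified_get_argument_value(args, kwargs, position, keyword_name):
    # Return a positional argument if one was passed, otherwise fall back to an
    # exact keyword-name lookup; raise if neither is present.
    if len(args) > position:
        return args[position]
    if keyword_name in kwargs:
        return kwargs[keyword_name]
    raise ArgumentError(keyword_name)


# langchain<0.1 style call, Chain.__call__(inputs=...): the old lookup works.
print(simplified_get_argument_value((), {"inputs": {"base": "two"}}, 0, "inputs"))

# langchain>=0.1 style call, Chain.invoke(input=...): looking up "inputs" finds
# nothing and raises, which is the failure this PR fixes.
try:
    simplified_get_argument_value((), {"input": {"base": "two"}}, 0, "inputs")
except ArgumentError as exc:
    print("ArgumentError:", exc)
```

The fix keeps a single traced function and simply picks the keyword name based on the installed langchain version, as shown in the diff below.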

## Checklist

- [x] Change(s) are motivated and described in the PR description
- [x] Testing strategy is described if automated tests are not included
in the PR
- [x] Risks are described (performance impact, potential for breakage,
maintainability)
- [x] Change is maintainable (easy to change, telemetry, documentation)
- [x] [Library release note
guidelines](https://ddtrace.readthedocs.io/en/stable/releasenotes.html)
are followed or label `changelog/no-changelog` is set
- [x] Documentation is included (in-code, generated user docs, [public
corp docs](https://github.com/DataDog/documentation/))
- [x] Backport labels are set (if
[applicable](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting))
- [x] If this PR changes the public interface, I've notified
`@DataDog/apm-tees`.
- [x] If change touches code that signs or publishes builds or packages,
or handles credentials of any kind, I've requested a review from
`@DataDog/security-design-and-guidance`.

## Reviewer Checklist

- [x] Title is accurate
- [x] All changes are related to the pull request's stated goal
- [x] Description motivates each change
- [x] Avoids breaking
[API](https://ddtrace.readthedocs.io/en/stable/versioning.html#interfaces)
changes
- [x] Testing strategy adequately addresses listed risks
- [x] Change is maintainable (easy to change, telemetry, documentation)
- [x] Release note makes sense to a user of the library
- [x] Author has acknowledged and discussed the performance implications
of this PR as reported in the benchmarks PR comment
- [x] Backport labels are set in a manner that is consistent with the
[release branch maintenance policy](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting)
Yun-Kim committed Apr 16, 2024
1 parent f74a0fd commit 7e49c10
Showing 4 changed files with 102 additions and 2 deletions.
10 changes: 8 additions & 2 deletions ddtrace/contrib/langchain/patch.py
@@ -597,7 +597,10 @@ def traced_chain_call(langchain, pin, func, instance, args, kwargs):
     span = integration.trace(pin, "%s.%s" % (instance.__module__, instance.__class__.__name__), interface_type="chain")
     final_outputs = {}
     try:
-        inputs = get_argument_value(args, kwargs, 0, "inputs")
+        if SHOULD_PATCH_LANGCHAIN_COMMUNITY:
+            inputs = get_argument_value(args, kwargs, 0, "input")
+        else:
+            inputs = get_argument_value(args, kwargs, 0, "inputs")
         if not isinstance(inputs, dict):
             inputs = {instance.input_keys[0]: inputs}
         if integration.is_pc_sampled_span(span):
@@ -645,7 +648,10 @@ async def traced_chain_acall(langchain, pin, func, instance, args, kwargs):
     span = integration.trace(pin, "%s.%s" % (instance.__module__, instance.__class__.__name__), interface_type="chain")
     final_outputs = {}
     try:
-        inputs = get_argument_value(args, kwargs, 0, "inputs")
+        if SHOULD_PATCH_LANGCHAIN_COMMUNITY:
+            inputs = get_argument_value(args, kwargs, 0, "input")
+        else:
+            inputs = get_argument_value(args, kwargs, 0, "inputs")
         if not isinstance(inputs, dict):
             inputs = {instance.input_keys[0]: inputs}
         if integration.is_pc_sampled_span(span):
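The shape of the fix above, including the string-to-dict normalization on the following context lines, can be illustrated outside of ddtrace. This is only a rough sketch: `is_langchain_community` and `extract_chain_inputs` are hypothetical names standing in for `SHOULD_PATCH_LANGCHAIN_COMMUNITY` and the logic inlined in `traced_chain_call` / `traced_chain_acall`.

```python
def extract_chain_inputs(args, kwargs, is_langchain_community, input_keys):
    # The one-word difference the PR fixes: invoke() passes "input",
    # __call__() passes "inputs".
    keyword = "input" if is_langchain_community else "inputs"
    if args:
        inputs = args[0]
    elif keyword in kwargs:
        inputs = kwargs[keyword]
    else:
        raise ValueError("could not find the chain inputs argument")
    # Mirrors the unchanged context lines above: a bare string is wrapped into a
    # dict keyed by the chain's first declared input key.
    if not isinstance(inputs, dict):
        inputs = {input_keys[0]: inputs}
    return inputs


print(extract_chain_inputs((), {"input": {"base": "two"}}, True, ["base"]))    # invoke(input={...})
print(extract_chain_inputs(("two",), {}, True, ["base"]))                      # invoke("two")
print(extract_chain_inputs((), {"inputs": {"base": "two"}}, False, ["base"]))  # __call__(inputs={...})
```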
5 changes: 5 additions & 0 deletions (new release note)
@@ -0,0 +1,5 @@
---
fixes:
  - |
    langchain: This fix resolves an issue where tracing ``Chain.invoke()`` instead of ``Chain.__call__()`` resulted in
    an ``ArgumentError`` due to an argument name change for inputs between the two methods.
18 changes: 18 additions & 0 deletions tests/contrib/langchain/test_langchain_community.py
@@ -484,6 +484,24 @@ def test_openai_math_chain_sync(langchain, langchain_openai, request_vcr):
         chain.invoke("what is two raised to the fifty-fourth power?")
 
 
+@pytest.mark.snapshot(token="tests.contrib.langchain.test_langchain_community.test_chain_invoke")
+def test_chain_invoke_dict_input(langchain, langchain_openai, request_vcr):
+    prompt_template = "what is {base} raised to the fifty-fourth power?"
+    prompt = langchain.prompts.PromptTemplate(input_variables=["adjective"], template=prompt_template)
+    chain = langchain.chains.LLMChain(llm=langchain_openai.OpenAI(temperature=0), prompt=prompt)
+    with request_vcr.use_cassette("openai_math_chain_sync.yaml"):
+        chain.invoke(input={"base": "two"})
+
+
+@pytest.mark.snapshot(token="tests.contrib.langchain.test_langchain_community.test_chain_invoke")
+def test_chain_invoke_str_input(langchain, langchain_openai, request_vcr):
+    prompt_template = "what is {base} raised to the fifty-fourth power?"
+    prompt = langchain.prompts.PromptTemplate(input_variables=["adjective"], template=prompt_template)
+    chain = langchain.chains.LLMChain(llm=langchain_openai.OpenAI(temperature=0), prompt=prompt)
+    with request_vcr.use_cassette("openai_math_chain_sync.yaml"):
+        chain.invoke("two")
+
+
 @pytest.mark.asyncio
 @pytest.mark.snapshot
 async def test_openai_math_chain_async(langchain, langchain_openai, request_vcr):
71 changes: 71 additions & 0 deletions (new trace snapshot)
@@ -0,0 +1,71 @@
[[
{
"name": "langchain.request",
"service": "",
"resource": "langchain.chains.llm.LLMChain",
"trace_id": 0,
"span_id": 1,
"parent_id": 0,
"type": "",
"error": 0,
"meta": {
"_dd.p.dm": "-0",
"_dd.p.tid": "660c678a00000000",
"langchain.request.inputs.base": "two",
"langchain.request.prompt": "what is {base} raised to the fifty-fourth power?",
"langchain.request.type": "chain",
"langchain.response.outputs.base": "two",
"langchain.response.outputs.text": "```text\\n2**54\\n```\\n...numexpr.evaluate(\"2**54\")...\\n",
"language": "python",
"runtime-id": "00697e1e790d47bb9828bc4ab549fac1"
},
"metrics": {
"_dd.measured": 1,
"_dd.top_level": 1,
"_dd.tracer_kr": 1.0,
"_sampling_priority_v1": 1,
"langchain.tokens.completion_tokens": 19,
"langchain.tokens.prompt_tokens": 202,
"langchain.tokens.total_cost": 0.00034100000000000005,
"langchain.tokens.total_tokens": 221,
"process_id": 61792
},
"duration": 31126000,
"start": 1712088970885366000
},
{
"name": "langchain.request",
"service": "",
"resource": "langchain_openai.llms.base.OpenAI",
"trace_id": 0,
"span_id": 2,
"parent_id": 1,
"type": "llm",
"error": 0,
"meta": {
"langchain.request.api_key": "...key>",
"langchain.request.model": "gpt-3.5-turbo-instruct",
"langchain.request.openai.parameters.frequency_penalty": "0",
"langchain.request.openai.parameters.max_tokens": "256",
"langchain.request.openai.parameters.model_name": "gpt-3.5-turbo-instruct",
"langchain.request.openai.parameters.n": "1",
"langchain.request.openai.parameters.presence_penalty": "0",
"langchain.request.openai.parameters.temperature": "0.0",
"langchain.request.openai.parameters.top_p": "1",
"langchain.request.prompts.0": "what is two raised to the fifty-fourth power?",
"langchain.request.provider": "openai",
"langchain.request.type": "llm",
"langchain.response.completions.0.finish_reason": "stop",
"langchain.response.completions.0.logprobs": "None",
"langchain.response.completions.0.text": "```text\\n2**54\\n```\\n...numexpr.evaluate(\"2**54\")...\\n"
},
"metrics": {
"_dd.measured": 1,
"langchain.tokens.completion_tokens": 19,
"langchain.tokens.prompt_tokens": 202,
"langchain.tokens.total_cost": 0.00034100000000000005,
"langchain.tokens.total_tokens": 221
},
"duration": 25786000,
"start": 1712088970890577000
}]]
