diff --git a/ddtrace/contrib/langchain/patch.py b/ddtrace/contrib/langchain/patch.py
index 920a3239483..c8658761db6 100644
--- a/ddtrace/contrib/langchain/patch.py
+++ b/ddtrace/contrib/langchain/patch.py
@@ -597,7 +597,10 @@ def traced_chain_call(langchain, pin, func, instance, args, kwargs):
     span = integration.trace(pin, "%s.%s" % (instance.__module__, instance.__class__.__name__), interface_type="chain")
     final_outputs = {}
     try:
-        inputs = get_argument_value(args, kwargs, 0, "inputs")
+        if SHOULD_PATCH_LANGCHAIN_COMMUNITY:
+            inputs = get_argument_value(args, kwargs, 0, "input")
+        else:
+            inputs = get_argument_value(args, kwargs, 0, "inputs")
         if not isinstance(inputs, dict):
             inputs = {instance.input_keys[0]: inputs}
         if integration.is_pc_sampled_span(span):
@@ -645,7 +648,10 @@ async def traced_chain_acall(langchain, pin, func, instance, args, kwargs):
     span = integration.trace(pin, "%s.%s" % (instance.__module__, instance.__class__.__name__), interface_type="chain")
     final_outputs = {}
     try:
-        inputs = get_argument_value(args, kwargs, 0, "inputs")
+        if SHOULD_PATCH_LANGCHAIN_COMMUNITY:
+            inputs = get_argument_value(args, kwargs, 0, "input")
+        else:
+            inputs = get_argument_value(args, kwargs, 0, "inputs")
         if not isinstance(inputs, dict):
             inputs = {instance.input_keys[0]: inputs}
         if integration.is_pc_sampled_span(span):
diff --git a/releasenotes/notes/fix-langchain-chain-invoke-792219fb95ac1889.yaml b/releasenotes/notes/fix-langchain-chain-invoke-792219fb95ac1889.yaml
new file mode 100644
index 00000000000..1cc28b583f6
--- /dev/null
+++ b/releasenotes/notes/fix-langchain-chain-invoke-792219fb95ac1889.yaml
@@ -0,0 +1,5 @@
+---
+fixes:
+  - |
+    langchain: This fix resolves an issue where tracing ``Chain.invoke()`` instead of ``Chain.__call__()`` resulted in
+    an ``ArgumentError`` due to an argument name change for inputs between the two methods.
diff --git a/tests/contrib/langchain/test_langchain_community.py b/tests/contrib/langchain/test_langchain_community.py
index 9f5d700e901..cf23c2684bc 100644
--- a/tests/contrib/langchain/test_langchain_community.py
+++ b/tests/contrib/langchain/test_langchain_community.py
@@ -484,6 +484,24 @@ def test_openai_math_chain_sync(langchain, langchain_openai, request_vcr):
     chain.invoke("what is two raised to the fifty-fourth power?")
 
 
+@pytest.mark.snapshot(token="tests.contrib.langchain.test_langchain_community.test_chain_invoke")
+def test_chain_invoke_dict_input(langchain, langchain_openai, request_vcr):
+    prompt_template = "what is {base} raised to the fifty-fourth power?"
+    prompt = langchain.prompts.PromptTemplate(input_variables=["base"], template=prompt_template)
+    chain = langchain.chains.LLMChain(llm=langchain_openai.OpenAI(temperature=0), prompt=prompt)
+    with request_vcr.use_cassette("openai_math_chain_sync.yaml"):
+        chain.invoke(input={"base": "two"})
+
+
+@pytest.mark.snapshot(token="tests.contrib.langchain.test_langchain_community.test_chain_invoke")
+def test_chain_invoke_str_input(langchain, langchain_openai, request_vcr):
+    prompt_template = "what is {base} raised to the fifty-fourth power?"
+    prompt = langchain.prompts.PromptTemplate(input_variables=["base"], template=prompt_template)
+    chain = langchain.chains.LLMChain(llm=langchain_openai.OpenAI(temperature=0), prompt=prompt)
+    with request_vcr.use_cassette("openai_math_chain_sync.yaml"):
+        chain.invoke("two")
+
+
 @pytest.mark.asyncio
 @pytest.mark.snapshot
 async def test_openai_math_chain_async(langchain, langchain_openai, request_vcr):
diff --git a/tests/snapshots/tests.contrib.langchain.test_langchain_community.test_chain_invoke.json b/tests/snapshots/tests.contrib.langchain.test_langchain_community.test_chain_invoke.json
new file mode 100644
index 00000000000..c2ddd110d58
--- /dev/null
+++ b/tests/snapshots/tests.contrib.langchain.test_langchain_community.test_chain_invoke.json
@@ -0,0 +1,71 @@
+[[
+  {
+    "name": "langchain.request",
+    "service": "",
+    "resource": "langchain.chains.llm.LLMChain",
+    "trace_id": 0,
+    "span_id": 1,
+    "parent_id": 0,
+    "type": "",
+    "error": 0,
+    "meta": {
+      "_dd.p.dm": "-0",
+      "_dd.p.tid": "660c678a00000000",
+      "langchain.request.inputs.base": "two",
+      "langchain.request.prompt": "what is {base} raised to the fifty-fourth power?",
+      "langchain.request.type": "chain",
+      "langchain.response.outputs.base": "two",
+      "langchain.response.outputs.text": "```text\\n2**54\\n```\\n...numexpr.evaluate(\"2**54\")...\\n",
+      "language": "python",
+      "runtime-id": "00697e1e790d47bb9828bc4ab549fac1"
+    },
+    "metrics": {
+      "_dd.measured": 1,
+      "_dd.top_level": 1,
+      "_dd.tracer_kr": 1.0,
+      "_sampling_priority_v1": 1,
+      "langchain.tokens.completion_tokens": 19,
+      "langchain.tokens.prompt_tokens": 202,
+      "langchain.tokens.total_cost": 0.00034100000000000005,
+      "langchain.tokens.total_tokens": 221,
+      "process_id": 61792
+    },
+    "duration": 31126000,
+    "start": 1712088970885366000
+  },
+  {
+    "name": "langchain.request",
+    "service": "",
+    "resource": "langchain_openai.llms.base.OpenAI",
+    "trace_id": 0,
+    "span_id": 2,
+    "parent_id": 1,
+    "type": "llm",
+    "error": 0,
+    "meta": {
+      "langchain.request.api_key": "...key>",
+      "langchain.request.model": "gpt-3.5-turbo-instruct",
+      "langchain.request.openai.parameters.frequency_penalty": "0",
+      "langchain.request.openai.parameters.max_tokens": "256",
+      "langchain.request.openai.parameters.model_name": "gpt-3.5-turbo-instruct",
+      "langchain.request.openai.parameters.n": "1",
+      "langchain.request.openai.parameters.presence_penalty": "0",
+      "langchain.request.openai.parameters.temperature": "0.0",
+      "langchain.request.openai.parameters.top_p": "1",
+      "langchain.request.prompts.0": "what is two raised to the fifty-fourth power?",
+      "langchain.request.provider": "openai",
+      "langchain.request.type": "llm",
+      "langchain.response.completions.0.finish_reason": "stop",
+      "langchain.response.completions.0.logprobs": "None",
+      "langchain.response.completions.0.text": "```text\\n2**54\\n```\\n...numexpr.evaluate(\"2**54\")...\\n"
+    },
+    "metrics": {
+      "_dd.measured": 1,
+      "langchain.tokens.completion_tokens": 19,
+      "langchain.tokens.prompt_tokens": 202,
+      "langchain.tokens.total_cost": 0.00034100000000000005,
+      "langchain.tokens.total_tokens": 221
+    },
+    "duration": 25786000,
+    "start": 1712088970890577000
+  }]]
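For context, a minimal sketch of the keyword-name difference this patch accounts for (illustrative only, not part of the diff; assumes langchain >= 0.1 with langchain-openai installed and an OpenAI key available): the legacy ``Chain.__call__`` entry point names its first argument ``inputs``, while ``Chain.invoke`` names it ``input``, so ``get_argument_value(args, kwargs, 0, ...)`` must look up a different keyword depending on which code path is patched.

```python
# Illustrative sketch mirroring the test setup above; requires an OpenAI key to actually run.
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_openai import OpenAI

prompt = PromptTemplate(
    input_variables=["base"],
    template="what is {base} raised to the fifty-fourth power?",
)
chain = LLMChain(llm=OpenAI(temperature=0), prompt=prompt)

# Legacy call path: the keyword argument is named "inputs".
chain(inputs={"base": "two"})

# invoke() call path: the keyword argument is named "input", which is why the
# wrapper's previous get_argument_value(args, kwargs, 0, "inputs") lookup
# raised an ArgumentError when the value was passed by keyword.
chain.invoke(input={"base": "two"})
```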