Skip to content

Commit

Permalink
[FLINK-35832][Table SQL / Planner] Fix IFNULL function returns incorr…
Browse files Browse the repository at this point in the history
…ect result

This closes #25099
  • Loading branch information
dylanhz authored and lsyldliu committed Jul 19, 2024
1 parent 255abc7 commit eaeface
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,19 @@
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.inference.CallContext;
import org.apache.flink.table.types.inference.TypeStrategy;
import org.apache.flink.table.types.logical.utils.LogicalTypeMerging;
import org.apache.flink.table.types.utils.TypeConversions;

import java.util.Arrays;
import java.util.List;
import java.util.Optional;

/** Type strategy specific for avoiding nulls. */
/**
* Type strategy specific for avoiding nulls. <br>
* If arg0 is non-nullable, output datatype is exactly the datatype of arg0. Otherwise, output
* datatype is the common type of arg0 and arg1. In the second case, output type is nullable only if
* both args are nullable.
*/
@Internal
class IfNullTypeStrategy implements TypeStrategy {

Expand All @@ -35,10 +43,16 @@ public Optional<DataType> inferType(CallContext callContext) {
final List<DataType> argumentDataTypes = callContext.getArgumentDataTypes();
final DataType inputDataType = argumentDataTypes.get(0);
final DataType nullReplacementDataType = argumentDataTypes.get(1);

if (!inputDataType.getLogicalType().isNullable()) {
return Optional.of(inputDataType);
}

return Optional.of(nullReplacementDataType);
return LogicalTypeMerging.findCommonType(
Arrays.asList(
inputDataType.getLogicalType(),
nullReplacementDataType.getLogicalType()))
.map(t -> t.copy(nullReplacementDataType.getLogicalType().isNullable()))
.map(TypeConversions::fromLogicalToDataType);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,11 @@ Stream<TestSetSpec> getTestSetSpecs() {
DataTypes.STRING())
.testSqlResult("TYPEOF(NULL)", "NULL", DataTypes.STRING()),
TestSetSpec.forFunction(BuiltInFunctionDefinitions.IF_NULL)
.onFieldsWithData(null, new BigDecimal("123.45"))
.andDataTypes(DataTypes.INT().nullable(), DataTypes.DECIMAL(5, 2).notNull())
.onFieldsWithData(null, new BigDecimal("123.45"), "Hello world")
.andDataTypes(
DataTypes.INT().nullable(),
DataTypes.DECIMAL(5, 2).notNull(),
DataTypes.STRING())
.withFunction(TakesNotNull.class)
.testResult(
$("f0").ifNull($("f0")),
Expand All @@ -81,6 +84,11 @@ Stream<TestSetSpec> getTestSetSpecs() {
"IFNULL(f1, f0)",
new BigDecimal("123.45"),
DataTypes.DECIMAL(12, 2).notNull())
.testResult(
$("f2").ifNull("0"),
"IFNULL(f2, '0')",
"Hello world",
DataTypes.STRING().notNull())
.testResult(
call("TakesNotNull", $("f0").ifNull(12)),
"TakesNotNull(IFNULL(f0, 12))",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2786,10 +2786,10 @@ class ScalarFunctionsTest extends ScalarTypesTestBase {
val url = "CAST('http://user:pass@host' AS VARCHAR(50))"
val base64 = "CAST('aGVsbG8gd29ybGQ=' AS VARCHAR(20))"

testSqlApi(s"IFNULL(SUBSTR($str1, 2, 3), $str2)", "el")
testSqlApi(s"IFNULL(SUBSTRING($str1, 2, 3), $str2)", "el")
testSqlApi(s"IFNULL(LEFT($str1, 3), $str2)", "He")
testSqlApi(s"IFNULL(RIGHT($str1, 3), $str2)", "ll")
testSqlApi(s"IFNULL(SUBSTR($str1, 2, 3), $str2)", "ell")
testSqlApi(s"IFNULL(SUBSTRING($str1, 2, 3), $str2)", "ell")
testSqlApi(s"IFNULL(LEFT($str1, 3), $str2)", "Hel")
testSqlApi(s"IFNULL(RIGHT($str1, 3), $str2)", "llo")
testSqlApi(s"IFNULL(REGEXP_EXTRACT($str1, 'H(.+?)l(.+?)$$', 2), $str2)", "lo")
testSqlApi(s"IFNULL(REGEXP_REPLACE($str1, 'e.l', 'EXL'), $str2)", "HEXLo")
testSqlApi(s"IFNULL(UPPER($str1), $str2)", "HELLO")
Expand All @@ -2799,9 +2799,9 @@ class ScalarFunctionsTest extends ScalarTypesTestBase {
testSqlApi(s"IFNULL(LPAD($str1, 7, $str3), $str2)", "heHello")
testSqlApi(s"IFNULL(RPAD($str1, 7, $str3), $str2)", "Hellohe")
testSqlApi(s"IFNULL(REPEAT($str1, 2), $str2)", "HelloHello")
testSqlApi(s"IFNULL(REVERSE($str1), $str2)", "ol")
testSqlApi(s"IFNULL(REVERSE($str1), $str2)", "olleH")
testSqlApi(s"IFNULL(REPLACE($str3, ' ', '_'), $str2)", "hello_world")
testSqlApi(s"IFNULL(SPLIT_INDEX($str3, ' ', 1), $str2)", "wo")
testSqlApi(s"IFNULL(SPLIT_INDEX($str3, ' ', 1), $str2)", "world")
testSqlApi(s"IFNULL(MD5($str1), $str2)", "8b1a9953c4611296a827abf8c47804d7")
testSqlApi(s"IFNULL(SHA1($str1), $str2)", "f7ff9e8b7bb2e09b70935a5d785e0cc5d9d0abf0")
testSqlApi(
Expand All @@ -2822,7 +2822,7 @@ class ScalarFunctionsTest extends ScalarTypesTestBase {
testSqlApi(
s"IFNULL(SHA2($str1, 256), $str2)",
"185f8db32271fe25f561a6fc938b2e264306ec304eda518007d1764826381969")
testSqlApi(s"IFNULL(PARSE_URL($url, 'HOST'), $str2)", "ho")
testSqlApi(s"IFNULL(PARSE_URL($url, 'HOST'), $str2)", "host")
testSqlApi(s"IFNULL(FROM_BASE64($base64), $str2)", "hello world")
testSqlApi(s"IFNULL(TO_BASE64($str3), $str2)", "aGVsbG8gd29ybGQ=")
testSqlApi(s"IFNULL(CHR(65), $str2)", "A")
Expand All @@ -2834,7 +2834,7 @@ class ScalarFunctionsTest extends ScalarTypesTestBase {
testSqlApi(s"IFNULL(RTRIM($str4), $str2)", " hello")
testSqlApi(s"IFNULL($str1 || $str2, $str2)", "HelloHi")
testSqlApi(s"IFNULL(SUBSTRING(UUID(), 9, 1), $str2)", "-")
testSqlApi(s"IFNULL(DECODE(ENCODE($str1, 'utf-8'), 'utf-8'), $str2)", "He")
testSqlApi(s"IFNULL(DECODE(ENCODE($str1, 'utf-8'), 'utf-8'), $str2)", "Hello")

testSqlApi(s"IFNULL(CAST(DATE '2021-04-06' AS VARCHAR(10)), $str2)", "2021-04-06")
testSqlApi(s"IFNULL(CAST(TIME '11:05:30' AS VARCHAR(8)), $str2)", "11:05:30")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2333,4 +2333,10 @@ class CalcITCase extends BatchTestBase {
Seq(row(2.0), row(2.0), row(2.0))
)
}

@Test
def testIfNull(): Unit = {
// reported in FLINK-35832
checkResult("SELECT IFNULL(JSON_VALUE('{\"a\":16}','$.a'),'0')", Seq(row("16")));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import org.apache.flink.table.api.config.ExecutionConfigOptions
import org.apache.flink.table.api.config.ExecutionConfigOptions.LegacyCastBehaviour
import org.apache.flink.table.api.internal.TableEnvironmentInternal
import org.apache.flink.table.catalog.CatalogDatabaseImpl
import org.apache.flink.table.data.{GenericRowData, MapData, RowData}
import org.apache.flink.table.data.{GenericRowData, MapData}
import org.apache.flink.table.planner.factories.TestValuesTableFactory
import org.apache.flink.table.planner.runtime.utils._
import org.apache.flink.table.planner.runtime.utils.BatchTestBase.row
Expand Down Expand Up @@ -810,4 +810,16 @@ class CalcITCase extends StreamingTestBase {
val expected = List("2.0", "2.0", "2.0")
assertThat(sink.getAppendResults.sorted).isEqualTo(expected.sorted)
}

@Test
def testIfNull(): Unit = {
// reported in FLINK-35832
val result = tEnv.sqlQuery("SELECT IFNULL(JSON_VALUE('{\"a\":16}','$.a'),'0')")
var sink = new TestingAppendSink
tEnv.toDataStream(result, DataTypes.ROW(DataTypes.STRING())).addSink(sink)
env.execute()

val expected = List("16")
assertThat(sink.getAppendResults.sorted).isEqualTo(expected.sorted)
}
}

0 comments on commit eaeface

Please sign in to comment.