diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index 7cd74ecd5..32004656f 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -483,6 +483,8 @@ def __eq__(self, rhs: object) -> Expr: Accepts either an expression or any valid PyArrow scalar literal value. """ + if rhs is None: + return self.is_null() if not isinstance(rhs, Expr): rhs = Expr.literal(rhs) return Expr(self.expr.__eq__(rhs.expr)) @@ -492,6 +494,8 @@ def __ne__(self, rhs: object) -> Expr: Accepts either an expression or any valid PyArrow scalar literal value. """ + if rhs is None: + return self.is_not_null() if not isinstance(rhs, Expr): rhs = Expr.literal(rhs) return Expr(self.expr.__ne__(rhs.expr)) diff --git a/python/tests/test_expr.py b/python/tests/test_expr.py index 1cf824a15..d046eb48c 100644 --- a/python/tests/test_expr.py +++ b/python/tests/test_expr.py @@ -153,8 +153,8 @@ def test_relational_expr(test_ctx): batch = pa.RecordBatch.from_arrays( [ - pa.array([1, 2, 3]), - pa.array(["alpha", "beta", "gamma"], type=pa.string_view()), + pa.array([1, 2, 3, None]), + pa.array(["alpha", "beta", "gamma", None], type=pa.string_view()), ], names=["a", "b"], ) @@ -171,6 +171,10 @@ def test_relational_expr(test_ctx): assert df.filter(col("b") != "beta").count() == 2 assert df.filter(col("a") == "beta").count() == 0 + assert df.filter(col("a") == None).count() == 1 # noqa: E711 + assert df.filter(col("a") != None).count() == 3 # noqa: E711 + assert df.filter(col("b") == None).count() == 1 # noqa: E711 + assert df.filter(col("b") != None).count() == 3 # noqa: E711 def test_expr_to_variant():