surj

Reputation: 4904

Get name / alias of column in PySpark

I am defining a column object like this:

column = F.col('foo').alias('bar')

I know I can get the full expression using str(column), but how can I get the column's alias only?

In the example, I'm looking for a function get_column_name where get_column_name(column) returns the string bar.
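For reference, on recent PySpark 3.x versions the string form of the column already contains the alias and looks roughly like this (the exact format varies between Spark versions):

from pyspark.sql import functions as F

column = F.col('foo').alias('bar')
print(str(column))  # something like: Column<'foo AS bar'>
# desired: get_column_name(column) -> 'bar'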

Upvotes: 12

Views: 6081

Answers (4)

ZygD

Reputation: 24386

I have used and reused this many times in several systems:

def get_col_name(col):
    # a plain string is already the column name
    if isinstance(col, str):
        return col
    # no backtick just before the closing quote: the alias contains no backticks,
    # so take what is inside Column<'...'> and keep the part after the last " AS "
    if str(col)[-3] != '`':
        return str(col).split("'")[-2].split(" AS ")[-1]
    # the alias contains backticks: un-double the escaped backticks, then take the
    # text between the last " AS `" and the closing "`'"
    return str(col).replace('``', '`').split(" AS `")[-1].split("`'")[-2]

It covers cases like:

  • several aliases
  • aliases containing several words
  • intentional backticks in aliases

Testing:

from pyspark.sql import functions as F

cols = [
    'foo',
    F.col('foo'),
    F.col('foo').alias('bar'),
    F.col('foo').alias('bar').alias('baz'),
    F.col('foo').alias('foo bar'),
    F.col('foo').alias('foo AS bar'),
    F.col('foo').alias('foo AS bar').alias('bar AS baz'),
    F.col('foo').alias('foo AS bar').alias('baz'),
    F.col('foo').alias('foo AS bar').alias('````bar AS baz````'),
    F.concat(F.lit('bar'), F.lit('baz')),
]

for c in cols:
    print(get_col_name(c))
# foo
# foo
# bar
# baz
# foo bar
# foo AS bar
# bar AS baz
# baz
# ````bar AS baz````
# concat(bar, baz)

Upvotes: 1

Brendan

Reputation: 2075

Regex is not needed. For PySpark 3.x it looks like the backticks were replaced with quotes, so this might not work out of the box on earlier Spark versions, but it should be easy enough to modify.

Note: requires Python 3.9+

from pyspark.sql import Column

def get_column_name(col: Column) -> str:
    """
    PySpark doesn't allow you to directly access the column name with respect to aliases
    from an unbound column. We have to parse this out from the string representation.

    This works on columns with one or more aliases as well as unaliased columns.

    Returns:
        Col name as str, with respect to aliasing
    """
    c = str(col).removeprefix("Column<'").removesuffix("'>")
    return c.split(' AS ')[-1]
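For earlier Spark versions, where the string representation still uses backticks (e.g. Column<foo AS `bar`>), a minimal sketch of the same idea could look like this (untested, assuming that repr format; get_column_name_legacy is a hypothetical helper name):

def get_column_name_legacy(col: Column) -> str:
    # assumes the pre-3.0 repr, e.g. Column<foo AS `bar`>
    c = str(col)
    if c.startswith("Column<") and c.endswith(">"):
        c = c[len("Column<"):-1]
    # keep the part after the last " AS ", then drop the surrounding backticks if any
    name = c.split(" AS ")[-1]
    if name.startswith("`") and name.endswith("`"):
        name = name[1:-1]
    return name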

Some tests to validate behavior:

import pytest
from pyspark.sql import SparkSession
from pyspark.sql import functions as f

@pytest.fixture(scope="session")
def spark() -> SparkSession:
    # Provide a session spark fixture for all tests
    yield SparkSession.builder.getOrCreate()

def test_get_col_name(spark):
    col = f.col('a')
    actual = get_column_name(col)
    assert actual == 'a'


def test_get_col_name_alias(spark):
    col = f.col('a').alias('b')
    actual = get_column_name(col)
    assert actual == 'b'


def test_get_col_name_multiple_alias(spark):
    col = f.col('a').alias('b').alias('c')
    actual = get_column_name(col)
    assert actual == 'c'


def test_get_col_name_longer(spark: SparkSession):
    """Added this test due to identifying a bug in the old implementation (if you use lstrip/rstrip, this will fail)"""
    col = f.col("local")
    actual = get_column_name(col)
    assert actual == "local"

Upvotes: 1

ebonnal

Reputation: 1167

Alternatively, we could use a wrapper function to tweak the behavior of the Column.alias and Column.name methods so that they store just the alias in an AS attribute:

from pyspark.sql import Column, SparkSession
from pyspark.sql.functions import col, explode, array, struct, lit
# an active SparkSession is required before Column expressions can be built
SparkSession.builder.getOrCreate()

def alias_wrapper(self, *alias, **kwargs):
    renamed_col = Column._alias(self, *alias, **kwargs)
    renamed_col.AS = alias[0] if len(alias) == 1 else alias
    return renamed_col

# keep the original method as Column._alias, route both alias() and name() through
# the wrapper, and default the AS attribute to None
Column._alias, Column.alias, Column.name, Column.AS = Column.alias, alias_wrapper, alias_wrapper, None

which then guarantees:

assert(col("foo").alias("bar").AS == "bar")
# `name` should act like `alias`
assert(col("foo").name("bar").AS == "bar")
# column without alias should have None in `AS`
assert(col("foo").AS is None)
# multialias should be handled
assert(explode(array(struct(lit(1), lit("a")))).alias("foo", "bar").AS == ("foo", "bar"))
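If you ever need to undo the patch, the original method is still available under Column._alias (a small sketch, assuming the patch above has been applied):

# restore the original behavior: point alias() and name() back to the saved method
Column.alias = Column._alias
Column.name = Column._alias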

Upvotes: 2

pault

Reputation: 43504

One way is through regular expressions:

from pyspark.sql.functions import col
column = col('foo').alias('bar')
print(column)
#Column<foo AS `bar`>

import re
print(re.findall(r"(?<=AS `)\w+(?=`>$)", str(column))[0])
#'bar'
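
On PySpark 3.x, where the string representation uses quotes instead of backticks (as noted in another answer), the pattern needs a small tweak, for example something like:

print(re.findall(r"(?<= AS )\w+(?='>$)", str(column))[0])
# 'bar'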

Upvotes: 5
