Python 3.10 introduces Pattern Matching

10 Aug, 2021

The other day I asked on LinkedIn whether people were excited about pattern matching coming to Python 3.10.

A third of the respondents didn’t know what pattern matching is, so what a perfect occasion to write a blog post about it!

Like most language constructs, pattern matching is not a necessity: Python reached version 3.10 before getting it. But it’s a feature that lets us write code more clearly and often more concisely.

And if you enjoy functional programming languages such as Erlang or Haskell, you probably missed pattern matching when writing Python.

I did, so I am excited that this feature is becoming available in Python 3.10 (currently in beta 4).

In Python, pattern matching will work like this:

match statement:
    case condition_1:
        do_something()
    case condition_2:
        do_something_else()
    case _:
        do_something_else_as_nothing_matched()

Python will check whether statement matches condition_1, then condition_2. The first match wins, and if nothing else matches, it will match _ (a sort of catch-all).

The condition can be pretty smart. For example, let’s say statement is (3, 5) and one of the conditions is (a, b). We can write something like:

match (3, 5):
    case (a, ):
        print(a)  # this does **not** match, as (a, ) is a single-element tuple
                  # while (3, 5) is a two-element tuple
    case (a, b):
        print(b ** 2)  # this will print 25, as (a, b) is a two-element tuple
                       # just like (3, 5)

You can see already how powerful this can be!

To see it in action, I dug up some code containing a couple of if/elif/else statements.

The original code is a bit larger and contains extra functionality which is out of scope for now, so I isolated the interesting bits:

from typing import Dict, Hashable, Iterable

def extract_fields_from_records(
    records: Iterable[Dict], 
    fields: set[Hashable],
    missing: str,
    _default: Dict[Hashable, Hashable]=dict()
    ) -> Iterable[Dict]:
    """
    Returns a generator of dictionaries whose keys are present in `fields`,
    starting from an iterable of dictionaries.

    :param records: An iterable of dictionaries
    :param fields: Fields to include
    :param missing: How to handle missing fields. If "omit", missing fields are
    simply omitted. If "fill", missing fields are added with the default value
    from _default. If "raise", a KeyError is raised if any values are missing.
    :param _default: When missing="fill", look up by key for the value to
    fill.
    """

    if missing == "omit":
        _records = (
            {k: v for k, v in rec.items() if k in fields} for rec in records
            )
    elif missing == "fill":
        _records = (
            {k: rec.get(k, _default.get(k, None)) for k in fields} 
            for rec in records
            )
    elif missing == "raise":
        _records = ({k: rec[k] for k in fields} for rec in records)
    else:
        raise ValueError(
            "Unknown value for missing. Valid values are"
            " 'omit', 'fill' and 'raise'."
        )
    return _records

Usage is simple

records = [
    {"age": 25, "height": 1.9},
    {"age": 45, "height": 1.6, "country": "NL"},
]

list(
    extract_fields_from_records(
        records, 
        {"age", "country"},
        missing="fill",
        _default={"country": "World"}
    )
)

resulting in


[
    {'country': 'World', 'age': 25},
    {'country': 'NL', 'age': 45}
]

So what would this function look like with pattern matching? It turns out it’s pretty simple:


def _extract_fields_from_records(
    records: Iterable[Dict], 
    fields: set[Hashable], 
    missing: str,
    _default: Dict[Hashable, Hashable]=dict()
    ) -> Iterable[Dict]:

    match missing:
        case "omit":
            fields = set(fields)
            _records = (
                {k: v for k, v in rec.items() if k in fields} 
                for rec in records
                )
        case "fill":
            _records = (
                {k: rec.get(k, _default.get(k, None)) for k in fields} 
                for rec in records
                )
        case "raise":
            _records = ({k: rec[k] for k in fields} for rec in records)
        case _:
            raise ValueError(
                "Unknown value for missing. Valid values are"
                " 'omit', 'fill' and 'raise'."
            )

    return _records

This example is simple though, as it’s very "linear": all I’m doing is saving myself some typing by not repeating missing == in every branch. It was just as easy to express with if/else.

However, pattern matching becomes powerful when we ask Python to decompose the pattern we want to match.

Let’s assume we’re getting data from somewhere and we need to shape it properly (how often does that happen!).

The data coming in can have the following form:

{
    "generic_key": [1, 2, 3], 
    # other stuff
}

or

{
    "generic_key": 1,   # integer and not a list
    # other stuff
}

or

{
    "log_me": "message",
}

In this last case we should not process the data and just log the message.

In the first two cases, instead, we need to do something with the elements of the list or with the single integer. The output should always be a list, because we don’t want the consumers of our function to write more logic to handle the list/integer split. How would we write such a function without pattern matching?

from typing import Optional

def transform_dictionary(dct: dict) -> Optional[list[int]]:
    message = dct.get("log_me")
    values = dct.get("generic_key", [])
    if message:
        print(message)  # this should be a log statement!!
    elif isinstance(values, list):
        return [value ** 2 for value in values]  # this is our transformation, insert your own
    elif isinstance(values, int):
        return [values ** 2]
    else:
        raise ValueError(f"Input {dct} is not of the required shape")

transform_dictionary({"log_me": "error"})
print(transform_dictionary({"generic_key": [1, 2, 3]}))
print(transform_dictionary({"generic_key": 1}))
> error
> [1, 4, 9]
> [1]

The above works fine, and it’s how you probably write Python today, but do you notice the mental overhead in trying to understand the data structure you’re getting? Without my explanation above, it’s hard. Pattern matching makes it easier to understand!

def _transform_dictionary(dct: dict) -> Optional[list[int]]:
    match dct:
        case {"generic_key": list(values)}:
            return [value ** 2 for value in values]
        case {"generic_key": int(value)}:
            return [value ** 2]
        case {"log_me": message}:
            print(message)
        case _:
            raise ValueError(f"Input {dct} is not of the required shape")

_transform_dictionary({"log_me": "error"})
print(_transform_dictionary({"generic_key": [1, 2, 3]}))
print(_transform_dictionary({"generic_key": 1}))
> error
> [1, 4, 9]
> [1]

Now, just by reading the code, I understand the formats dct could have, no documentation needed. We also see some powerful features we didn’t see before. By writing list(values) and int(value), we ask Python to check the type of what sits under "generic_key" and, for the inputs above, to bind [1, 2, 3] to values and 1 to value, respectively.

There is of course much more to it (guards, for example). If you’re curious, PEP 634 offers the full spec, and PEP 636 presents a tutorial.

That’s it, thank you for reading. Follow me on Twitter @gglanzani for more good stuff!
