Implementing Python’s `cmp_to_key` function

Sorting functions in programming languages often take a parameter which allows the user to control how comparison is done when sorting. The most direct way to do this is to have the parameter be a comparison function, which takes two arguments and returns a value indicating how the first argument compares to the second argument.

For example, in Python, an implementation of insertion sort that allows the user to control how comparison is done might look like this:

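import operator
from typing import Callable

# T is left unbounded: a custom cmp can order any element type. (The
# default, operator.lt, does assume the elements support <.)
def isort[T](
    ls: list[T],
    cmp: Callable[[T, T], bool] = operator.lt
) -> None:

    for i, x in enumerate(ls[1:], start=1):
        j = i
        # Shift earlier elements right while x must come before them.
        # Testing cmp(x, ...) rather than not cmp(..., x) leaves equal
        # elements in their original order, so the sort is stable.
        while j > 0 and cmp(x, ls[j - 1]):
            ls[j] = ls[j - 1]
            j -= 1
        ls[j] = x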
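To see the comparison-function interface in use, here is a small sketch that restates isort, testing cmp(x, ls[j - 1]) in the loop so that the sort is stable. It uses a module-level `TypeVar` instead of the `def isort[T]` syntax so that it also runs on Python versions before 3.12; the example inputs are made up.

```python
import operator
from typing import Callable, TypeVar

T = TypeVar("T")


def isort(ls: list[T], cmp: Callable[[T, T], bool] = operator.lt) -> None:
    """Stable in-place insertion sort; cmp(a, b) is True when a goes before b."""
    for i in range(1, len(ls)):
        x = ls[i]
        j = i
        # Shift earlier elements right while x must come before them.
        while j > 0 and cmp(x, ls[j - 1]):
            ls[j] = ls[j - 1]
            j -= 1
        ls[j] = x


nums = [3, 1, 2]
isort(nums, operator.gt)  # using "greater than" as the ordering sorts descending
print(nums)  # [3, 2, 1]

words = ["pear", "fig", "plum", "kiwi", "date"]
isort(words, lambda a, b: len(a) < len(b))  # order by length only
print(words)  # ['fig', 'pear', 'plum', 'kiwi', 'date']
```

Because equal-length words are never shifted past each other, "pear", "plum", "kiwi" and "date" keep their original relative order.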

However, Python’s built-in sorting functions work in a different way. Instead of taking a comparison function, they take a key function, which is a function of a single argument that may return anything. This function is used to generate a “key” for each object in the list to be sorted. The sorting is then based on comparison of those keys, rather than the underlying objects themselves, using the built-in comparison operator <. So to write a comparison function which is equivalent to the key function, we just apply the key function to the two arguments, and then compare the results using <. Here’s an implementation of a version of isort which uses a key function:

from typing import Protocol, Self

class Comparable(Protocol):
    def __lt__(self: Self, other: Self) -> bool:
        ...

def key_to_cmp[T, U: Comparable](
    key: Callable[[T], U]
) -> Callable[[T, T], bool]:

    return lambda x, y: key(x) < key(y)

def isort_by_key[T, U: Comparable](
    ls: list[T],
    key: Callable[[T], U] = lambda x: x
) -> None:

    isort(ls, key_to_cmp(key))
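One cost of going through key_to_cmp is that the key function runs twice per comparison, whereas Python's documentation for `list.sort` says the built-in sort computes each element's key only once. A minimal sketch of the classic decorate-sort-undecorate workaround follows, assuming it is acceptable to build a temporary list of (key, element) pairs. `isort_by_key_cached` and `counted_len` are hypothetical names, and isort is restated so the block runs on its own.

```python
import operator
from typing import Any, Callable, TypeVar

T = TypeVar("T")


def isort(ls: list[Any], cmp: Callable[[Any, Any], bool] = operator.lt) -> None:
    """Stable in-place insertion sort; cmp(a, b) is True when a goes before b."""
    for i in range(1, len(ls)):
        x = ls[i]
        j = i
        while j > 0 and cmp(x, ls[j - 1]):
            ls[j] = ls[j - 1]
            j -= 1
        ls[j] = x


def isort_by_key_cached(ls: list[T], key: Callable[[T], Any] = lambda x: x) -> None:
    # Decorate: pair every element with its key, calling key exactly once each.
    decorated = [(key(x), x) for x in ls]
    # Sort the pairs by comparing only the cached keys, never the elements.
    isort(decorated, lambda a, b: a[0] < b[0])
    # Undecorate: write the elements back into ls in sorted order.
    ls[:] = [x for _, x in decorated]


calls = 0


def counted_len(s: str) -> int:
    """len() that also counts how many times it has been called."""
    global calls
    calls += 1
    return len(s)


words = ["banana", "fig", "apple", "kiwi"]
isort_by_key_cached(words, counted_len)
print(words, calls)  # ['fig', 'kiwi', 'apple', 'banana'] 4
```

The trade-off is extra memory for the pairs in exchange for exactly one key call per element, no matter how many comparisons the sort makes.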

Now, it’s easy to convert a key function to a comparison function: the key_to_cmp function here shows us how to do it. But what about the reverse? How do we convert a comparison function to a key function? At least to me, it wasn’t totally obvious how to do this.

However, the solution is pretty simple once you realize that the return value from a key function can be absolutely anything, and that Python allows you to overload the < operator. So to do the conversion, we just need to define a new class, with < overloaded so that it uses the given comparison function when applied to instances of the class. The constructor for this class will then work as the function that does the conversion.

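from dataclasses import dataclass

def cmp_to_key[T](
    cmp: Callable[[T, T], bool]
) -> Callable[[T], Comparable]:

    # Key uses the outer T; writing `class Key[T]` would introduce a
    # second, unrelated type parameter that shadows it.
    @dataclass(slots=True)
    class Key:
        value: T

        # Comparing two keys with < delegates to the user's cmp.
        def __lt__(self: Self, other: Self) -> bool:
            return cmp(self.value, other.value)

    # The class itself serves as the key function: Key(x) wraps x.
    return Key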
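A quick usage sketch: since the built-in `sorted` only uses `<` between keys, the class returned by cmp_to_key can be passed straight to it as the `key` argument. This restates cmp_to_key with a plain `@dataclass`, a module-level `TypeVar` and an `Any` return type (adaptations so it runs before Python 3.12 without the Comparable protocol); the input lists are invented for the example.

```python
from dataclasses import dataclass
from typing import Any, Callable, TypeVar

T = TypeVar("T")


def cmp_to_key(cmp: Callable[[T, T], bool]) -> Callable[[T], Any]:
    @dataclass
    class Key:
        value: Any

        def __lt__(self, other: "Key") -> bool:
            # Delegate < on keys to the user's comparison on the wrapped values.
            return cmp(self.value, other.value)

    # The class itself is the key function: Key(x) wraps x.
    return Key


fruits = ["banana", "Cherry", "apple"]
# Case-insensitive order; plain sorted(fruits) would put "Cherry" first.
print(sorted(fruits, key=cmp_to_key(lambda a, b: a.lower() < b.lower())))
# ['apple', 'banana', 'Cherry']

print(sorted([3, 1, 2], key=cmp_to_key(lambda a, b: a > b)))
# [3, 2, 1]
```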

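Python actually provides a cmp_to_key function in the functools module which does essentially this. One difference is that functools.cmp_to_key expects an old-style three-way comparison function, which returns a negative number, zero or a positive number depending on whether the first argument is less than, equal to or greater than the second, rather than a bool-returning “less than” function like the ones used here.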
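Because functools.cmp_to_key works with three-way comparisons, a bool less-than function like the cmp used in this post needs a small adapter before it can be handed over. The sketch below assumes the predicate is a consistent strict ordering; `lt_to_three_way` and `shorter` are hypothetical helpers, not part of functools.

```python
import functools
from typing import Callable, TypeVar

T = TypeVar("T")


def lt_to_three_way(lt: Callable[[T, T], bool]) -> Callable[[T, T], int]:
    """Turn a less-than predicate into a negative/zero/positive comparison."""
    def three_way(a: T, b: T) -> int:
        if lt(a, b):
            return -1  # a sorts before b
        if lt(b, a):
            return 1   # a sorts after b
        return 0       # neither is less, so treat a and b as equal
    return three_way


def shorter(a: str, b: str) -> bool:
    return len(a) < len(b)


words = ["banana", "fig", "apple", "kiwi"]
key = functools.cmp_to_key(lt_to_three_way(shorter))
print(sorted(words, key=key))  # ['fig', 'kiwi', 'apple', 'banana']
```

The second call to the predicate only happens when the first returns False, so the adapter makes at most two predicate calls per comparison.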

The return type of this cmp_to_key function is an interesting one. Even though we know that the function it returns produces instances of the locally defined Key class (in fact, the returned function is the Key class itself), we haven’t indicated this in the return type. Indeed, we can’t, because the Key class is defined within the scope of cmp_to_key, while the type declaration for cmp_to_key can only refer to names defined in the outer scope. And this is fine, because we shouldn’t be able to be any more specific about the return type anyway. The Key class is an implementation detail that callers of cmp_to_key have no need to be aware of. All callers need to know is that the function returned by cmp_to_key turns values of type T into comparable values, which is exactly what the Callable[[T], Comparable] type conveys.
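One way to see why no single name could appear in that return type: every call to cmp_to_key runs the class statement again and so creates a brand-new Key class. The sketch below checks this, restating cmp_to_key with a plain dataclass so it runs on its own; its `Comparable` protocol types `other` as `Any` rather than `Self`, an assumption made to keep it runnable on older Pythons.

```python
from dataclasses import dataclass
from typing import Any, Callable, Protocol, TypeVar

T = TypeVar("T")


class Comparable(Protocol):
    def __lt__(self, other: Any) -> bool: ...


def cmp_to_key(cmp: Callable[[T, T], bool]) -> Callable[[T], Comparable]:
    @dataclass
    class Key:
        value: Any

        def __lt__(self, other: "Key") -> bool:
            return cmp(self.value, other.value)

    return Key


by_lt = cmp_to_key(lambda a, b: a < b)
by_gt = cmp_to_key(lambda a, b: a > b)

# Each call produced its own class, so the keys have different types...
print(type(by_lt(1)) is type(by_gt(1)))  # False
# ...even though both classes have the same name.
print(type(by_lt(1)).__name__, type(by_gt(1)).__name__)  # Key Key
# What callers can rely on is the < behaviour that Comparable describes.
print(by_lt(1) < by_lt(2), by_gt(1) < by_gt(2))  # True False
```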

This way of typing cmp_to_key relies on subtyping: namely, the fact that Comparable is a supertype of every Key class that cmp_to_key can create (and, more generally, of every type with a suitable __lt__ method). In a type system without subtyping, I think it might be possible to express the type using an existential type, something like this in pseudo-Haskell syntax:

cmpToKey :: (a -> a -> Bool) -> exists b. Ord b => (a -> b)

This isn’t actually a legal Haskell type signature, because Haskell doesn’t allow you to express existential types directly; exists isn’t actually a Haskell keyword, though forall is. You’re supposed to express existential types by taking advantage of certain equivalences which allow you to express exists in terms of forall. I can’t figure out any way by which the type above could be expressed using forall, although given my level of fluency with this stuff, this is not a strong indicator that it’s impossible. Even if it is possible to express the type of cmpToKey in Haskell, I doubt it’s possible to actually implement it, because it would require a declaration of a type class instance within the function—something like this, which is definitely not valid Haskell syntax:

cmpToKey f =
    newtype Key a = Key a

    instance Ord (Key a) where
        Key x <= Key y = f x y
    
    Key

There’s probably something to be said here about the difference between type classes and classes in object-oriented programming, and the advantages and disadvantages of these two different approaches to polymorphism, but I don’t feel like I understand the subject well enough to say any more.
