Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 128 additions & 2 deletions permuta/perm_sets/permset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import multiprocessing
from itertools import islice
from typing import ClassVar, Dict, Iterable, List, NamedTuple, Optional, Union
from itertools import combinations, islice
from typing import ClassVar, Dict, Iterable, Iterator, List, NamedTuple, Optional, Union

from ..patterns import MeshPatt, Perm
from ..permutils import is_finite, is_insertion_encodable, is_polynomial
Expand Down Expand Up @@ -202,6 +202,132 @@ def _all(self) -> Iterable[Perm]:
yield from gen
length += 1

def right_juxtaposition(self, other: "Av") -> "Av":
"""Compute the basis of the juxtaposition of two permutation classes.

Given self = Av(B1) and other = Av(B2), returns the permutation class
E = Av(B) where E consists of all permutations that can be written as
the juxtaposition of a permutation from self on the left and a
permutation from other on the right.

Raises NotImplementedError: If either basis is a MeshBasis.
"""
if not isinstance(self.basis, Basis) or not isinstance(other.basis, Basis):
raise NotImplementedError(Av._BASIS_ONLY_MSG)

candidates: List[Perm] = []

for b1 in self.basis:
for b2 in other.basis:
# |σ| = 0 case: no overlap
candidates.extend(self._sigma_0_candidates(b1, b2))
# |σ| = 1 case: one element overlap
candidates.extend(self._sigma_1_candidates(b1, b2))

# Basis constructor automatically minimizes
return Av(Basis(*candidates))

def above_juxtaposition(self, other: "Av") -> "Av":
"""Compute the basis of the above juxtaposition of two permutation classes.

Given self = Av(B1) and other = Av(B2), returns the permutation class
where self is on the bottom and other is on top.

This is computed by taking inverses, computing right_juxtaposition,
then inverting the result.

Raises NotImplementedError: If either basis is a MeshBasis.
"""
if not isinstance(self.basis, Basis) or not isinstance(other.basis, Basis):
raise NotImplementedError(Av._BASIS_ONLY_MSG)

# Compute inverse classes
self_inverse = Av(Basis(*[p.inverse() for p in self.basis]))
other_inverse = Av(Basis(*[p.inverse() for p in other.basis]))

# Compute right juxtaposition of inverses
result_inverse = self_inverse.right_juxtaposition(other_inverse)

# Return inverse of result
return Av(Basis(*[p.inverse() for p in result_inverse.basis]))

@staticmethod
def _sigma_0_candidates(b1: Perm, b2: Perm) -> Iterator[Perm]:
"""Generate candidates where left and right patterns don't overlap.

Generates all permutations of length |b1| + |b2| where the first |b1|
positions have pattern b1 and the last |b2| positions have pattern b2.
"""
n1, n2 = len(b1), len(b2)
total = n1 + n2

# Choose which values go to the left block
for left_values in combinations(range(total), n1):
right_values = [v for v in range(total) if v not in left_values]

# Build the permutation
result = [0] * total
# Left positions get values according to pattern b1
for pos in range(n1):
result[pos] = left_values[b1[pos]]
# Right positions get values according to pattern b2
for pos in range(n2):
result[n1 + pos] = right_values[b2[pos]]

yield Perm(result)

@staticmethod
def _sigma_1_candidates( # pylint: disable=R0914
b1: Perm, b2: Perm
) -> Iterator[Perm]:
"""Generate candidates where left and right patterns overlap by one element.

Generates all permutations of length |b1| + |b2| - 1 where the first |b1|
positions have pattern b1 and the last |b2| positions have pattern b2,
with position |b1| - 1 shared between both patterns.
"""
n1, n2 = len(b1), len(b2)
total = n1 + n2 - 1

# The shared position is at index n1 - 1
# Its value v must satisfy: v = b1[-1] + b2[0]
# (it must be at rank b1[-1] among left and rank b2[0] among right values)
v = b1[-1] + b2[0]

# Values less than v: {0, ..., v-1}
# Values greater than v: {v+1, ..., total-1}
values_below = list(range(v))
values_above = list(range(v + 1, total))

# Left block needs b1[-1] values below v, right block gets the rest
k1 = b1[-1] # number of values < v in left block

# Iterate over all ways to partition values below v
for left_below in combinations(values_below, k1):
right_below = [x for x in values_below if x not in left_below]

# Iterate over all ways to partition values above v
for left_above in combinations(values_above, n1 - 1 - k1):
right_above = [x for x in values_above if x not in left_above]

# Build the left and right value sets
left_values = sorted(list(left_below) + [v] + list(left_above))
right_values = sorted(right_below + [v] + right_above)

# Build the permutation
result = [0] * total

# Left positions (0 to n1-1) get values according to pattern b1
for pos in range(n1):
result[pos] = left_values[b1[pos]]

# Right positions (n1-1 to total-1) get values according to pattern b2
# But position n1-1 is already set, so we only set n1 to total-1
for pos in range(1, n2):
result[n1 - 1 + pos] = right_values[b2[pos]]

yield Perm(result)

def __str__(self) -> str:
return f"Av({','.join(str(p) for p in self.basis)})"

Expand Down
157 changes: 157 additions & 0 deletions tests/perm_sets/test_av.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,3 +291,160 @@ def test_invalid_ops_with_mesh_patt():
Av(MeshBasis(Perm((0, 1)))).is_insertion_encodable()
with pytest.raises(NotImplementedError):
Av(MeshBasis(Perm((0, 1)))).is_polynomial()


# Tests for right_juxtaposition


def test_right_juxtaposition_basic():
"""Test Av(21) | Av(12) = Av(213, 312)."""
av_21 = Av(Basis(Perm((1, 0))))
av_12 = Av(Basis(Perm((0, 1))))
result = av_21.right_juxtaposition(av_12)
expected_basis = {Perm((1, 0, 2)), Perm((2, 0, 1))}
assert set(result.basis) == expected_basis


def test_right_juxtaposition_same_class():
"""Test Av(21) | Av(21) gives expected basis."""
av_21 = Av(Basis(Perm((1, 0))))
result = av_21.right_juxtaposition(av_21)
# Basis should be {321, 2143, 2431} = {(2,1,0), (1,0,3,2), (1,3,2,0)}
expected_basis = {Perm((2, 1, 0)), Perm((1, 0, 3, 2)), Perm((2, 0, 3, 1))}
assert set(result.basis) == expected_basis


def test_right_juxtaposition_enumeration():
"""Test that juxtaposition class has correct enumeration."""
av_21 = Av(Basis(Perm((1, 0))))
av_12 = Av(Basis(Perm((0, 1))))
result = av_21.right_juxtaposition(av_12)
# [Av(21)|Av(12)] = permutations that can be split into decreasing|increasing
# Enumeration: 1, 1, 2, 4, 8, 16, 32 (powers of 2 starting at n=2)
assert result.enumeration(6) == [1, 1, 2, 4, 8, 16, 32]


def test_right_juxtaposition_multiple_basis_elements():
"""Test juxtaposition with multiple basis elements."""
# Only contains empty and singleton permutations
av_21_12 = Av(Basis(Perm((1, 0)), Perm((0, 1))))
av_132 = Av(Basis(Perm((0, 2, 1))))
result = av_21_12.right_juxtaposition(av_132)
# The result should be a valid Av object with a minimized basis
assert isinstance(result.basis, Basis)
assert len(result.basis) > 0


def test_right_juxtaposition_longer_patterns():
"""Test juxtaposition with longer patterns."""
av_132 = Av(Basis(Perm((0, 2, 1))))
av_231 = Av(Basis(Perm((1, 2, 0))))
result = av_132.right_juxtaposition(av_231)
# Verify result is valid and has expected structure
assert isinstance(result.basis, Basis)
# All basis elements should have length between 3 and 6 (|b1|+|b2|-1 to |b1|+|b2|)
for perm in result.basis:
assert 5 <= len(perm) <= 6


def test_right_juxtaposition_mesh_basis_raises():
"""Test that juxtaposition with MeshBasis raises NotImplementedError."""
av_classical = Av(Basis(Perm((1, 0))))
av_mesh = Av(MeshBasis(Perm((0, 1))))
with pytest.raises(NotImplementedError):
av_classical.right_juxtaposition(av_mesh)
with pytest.raises(NotImplementedError):
av_mesh.right_juxtaposition(av_classical)


def test_right_juxtaposition_containment():
"""Test that permutations in the juxtaposition class can be split correctly."""
av_21 = Av(Basis(Perm((1, 0))))
av_12 = Av(Basis(Perm((0, 1))))
result = av_21.right_juxtaposition(av_12)

# Check some permutations that should be in the class
# 21 can be split as (2)|(1) where (2) is decreasing and (1) is increasing
assert Perm((1, 0)) in result
# 12 can be split as ()|(12) where () is trivially decreasing and (12) is increasing
assert Perm((0, 1)) in result
# 1 is trivially in the class
assert Perm((0,)) in result

# Check some permutations that should NOT be in the class
# 213 = (1,0,2) is a basis element, so not in the class
assert Perm((1, 0, 2)) not in result
# 312 = (2,0,1) is a basis element, so not in the class
assert Perm((2, 0, 1)) not in result


# Tests for above_juxtaposition


def test_above_juxtaposition_basic():
"""Test basic above juxtaposition with Av(21) below and Av(12) above."""
av_21 = Av(Basis(Perm((1, 0))))
av_12 = Av(Basis(Perm((0, 1))))
result = av_21.above_juxtaposition(av_12)
# Result should be valid Av with Basis
assert isinstance(result.basis, Basis)
assert len(result.basis) > 0


def test_above_juxtaposition_inverse_relationship():
"""Test that above_juxtaposition relates to right_juxtaposition via inverses."""
av_21 = Av(Basis(Perm((1, 0))))
av_132 = Av(Basis(Perm((0, 2, 1))))

# Compute above juxtaposition directly
above_result = av_21.above_juxtaposition(av_132)

# Compute via inverses manually
av_21_inv = Av(Basis(*[p.inverse() for p in av_21.basis]))
av_132_inv = Av(Basis(*[p.inverse() for p in av_132.basis]))
right_result = av_21_inv.right_juxtaposition(av_132_inv)
manual_result = Av(Basis(*[p.inverse() for p in right_result.basis]))

# The bases should be equivalent
assert set(above_result.basis) == set(manual_result.basis)


def test_above_juxtaposition_enumeration():
"""Test that above juxtaposition class has expected enumeration."""
av_21 = Av(Basis(Perm((1, 0))))
av_12 = Av(Basis(Perm((0, 1))))
result = av_21.above_juxtaposition(av_12)
# Permutations that can be split by value: lower values decreasing, upper increasing
# This should give 2^(n-1) for n >= 1
assert result.enumeration(6) == [1, 1, 2, 4, 8, 16, 32]


def test_above_juxtaposition_same_class():
"""Test above juxtaposition with the same class."""
av_21 = Av(Basis(Perm((1, 0))))
result = av_21.above_juxtaposition(av_21)
# Should be valid and have a non-empty basis
assert isinstance(result.basis, Basis)
assert len(result.basis) > 0


def test_above_juxtaposition_longer_patterns():
"""Test above juxtaposition with longer patterns."""
av_132 = Av(Basis(Perm((0, 2, 1))))
av_231 = Av(Basis(Perm((1, 2, 0))))
result = av_132.above_juxtaposition(av_231)
# Verify result is valid
assert isinstance(result.basis, Basis)
# All basis elements should have length between 5 and 6
for perm in result.basis:
assert 5 <= len(perm) <= 6


def test_above_juxtaposition_mesh_basis_raises():
"""Test that above_juxtaposition with MeshBasis raises NotImplementedError."""
av_classical = Av(Basis(Perm((1, 0))))
av_mesh = Av(MeshBasis(Perm((0, 1))))
with pytest.raises(NotImplementedError):
av_classical.above_juxtaposition(av_mesh)
with pytest.raises(NotImplementedError):
av_mesh.above_juxtaposition(av_classical)