# Expanding symmetric functions in the Schur basis
# of S/I (SageMath computations accompanying arXiv:1910.00207v1).

# Parameters: `\omega` below is the rectangle with `k` rows
# of length `n-k`.
n = 9
k = 5

R = PolynomialRing(QQ, ['a%s'%p for p in range(1, k+1)])
# The base ring, called `\mathbb{k}` in the paper.
# We are taking a polynomial ring in `k` variables here.
# NOTE: its generators are printed as ``a1, ..., ak``, but the
# `a_1, a_2, ..., a_k` actually used below are sign-twisted
# versions of them (see the definition of ``a``).
Sym = SymmetricFunctions(R) # Symmetric functions with coefficients in `R`.
s = Sym.s() # The Schur basis.
h = Sym.h() # The h-basis.
e = Sym.e() # The e-basis.

a = [(-1)**(n-k-1) * ai for ai in R.gens()]
# The vector `(a_1, a_2, ..., a_k)`, where
# `a_i = (-1)^{n-k-1}` times the `i`-th generator of `R`.
# (``poly_is_pos`` speaks of polynomials "in the b's";
# presumably the generators of `R` play the role of those `b_i`.)

omega = Partition([n-k] * k) # The partition `\omega` (the `k \times (n-k)` rectangle).

from itertools import permutations, combinations

def in_box(lam):
    # Decide whether the partition ``lam`` is contained in the
    # rectangle `\omega`: it must have at most `k` parts, and
    # its first (largest) part must be at most `n-k`.
    # The empty partition is always contained in `\omega`.
    if not len(lam):
        return True
    return len(lam) <= k and lam[0] <= n - k

# ``hei[i]`` is the elementary symmetric function `e_i`
# rewritten in the h-basis, for `0 \leq i \leq k`.
hei = list(map(h, (e[i] for i in range(k + 1))))

@cached_function
def reduce_h(g):
    # One-step reduction of the complete homogeneous symmetric
    # function `h_g` modulo the ideal generated by the
    # `e_i` for `i > k` and the `h_{n-k+i} - a_i` for
    # `0 < i \leq k`.
    # Every branch returns an element of the h-basis ``h``.
    if g < 0:
        # `h_g = 0` for negative `g`.
        return h.zero()
    if g <= n-k:
        # Small enough already; nothing to reduce.
        return h[g]
    if g <= n:
        # Here `g = n-k+i` with `1 \leq i \leq k`, and `h_{n-k+i} = a_i`
        # (stored at index `i-1` of ``a``). The scalar is coerced
        # into ``h`` so that this branch returns the same kind of
        # object as the others.
        return h(a[g - n + k - 1])
    # `g > n`: use `h_g = \sum_{i \geq 1} (-1)^{i-1} e_i h_{g-i}`.
    # Terms with `i > k` vanish because such `e_i` lie in the ideal.
    # ``hei[i]`` is `e_i` written in the h-basis.
    return h.sum((-1)**(i-1) * reduce_h(g-i) * hei[i] for i in range(1, k+1))

# Fill the cache of ``reduce_h`` for the degrees `0, ..., 2n-k-1`.
# Presumably this bound is chosen to cover the Jacobi-Trudi entries
# `h_{\lambda_i - i + j}` that arise when reducing a product of two
# Schur functions indexed by partitions inside `\omega`.
# A plain loop is used because only the caching side effect matters;
# a list comprehension would build and discard a list of results.
for _g in range(2*n - k):
    reduce_h(_g)

@cached_function
def sred(lam):
    # Reduce the Schur function `s_{\text{lam}}` modulo the ideal
    # generated by the `e_i` for `i > k` and the `h_{n-k+i} - a_i`
    # for `0 < i \leq k`.
    # This function and :func:`reduce` call each other recursively.
    length = len(lam)
    if length > k:
        # More than `k` rows: `s_{\text{lam}}` lies in the ideal.
        return s.zero()
    if in_box(lam):
        # Already a basis element of the quotient.
        return s[lam]
    # Jacobi-Trudi matrix of `s_{\text{lam}}`, whose entries are the
    # one-step-reduced `h_{\text{lam}_r - r + c}` (0-indexed rows r, columns c).
    jacobi_trudi = [
        [reduce_h(lam[row] - row + col) for col in range(length)]
        for row in range(length)
    ]
    one_step = Matrix(h, jacobi_trudi).det()
    # The determinant is only a one-step reduction; finish the job.
    return reduce(one_step)

def reduce(f):
    # Reduce the symmetric function ``f`` modulo the ideal generated
    # by the `e_i` for `i > k` and the `h_{n-k+i} - a_i` for
    # `0 < i \leq k`: expand ``f`` in the Schur basis and reduce
    # every Schur function separately via :func:`sred`.
    # NOTE: this name shadows ``functools.reduce`` if that is in scope.
    result = s.zero()
    for partition, coeff in s(f):
        result = result + coeff * sred(partition)
    return result

def complement(nu):
    # The complement of the partition `\nu` (= ``nu``)
    # in the rectangle `\omega`, commonly denoted `\nu^\vee`.
    # Pad `\nu` with zeros to a `k`-tuple, subtract every entry
    # from `n-k`, and reverse the order of the entries.
    #
    # ``nu`` may be a Partition, a list or a tuple of parts.
    # Raises ValueError if ``nu`` is not contained in `\omega`.
    if not in_box(nu):
        raise ValueError("nu is not subset of omega")
    # ``list`` (rather than slicing) yields a fresh mutable list for
    # every accepted input type; slicing a tuple gives a tuple, which
    # cannot be padded in place.
    padded = list(nu) + [0] * (k - len(nu))
    return Partition([n - k - part for part in reversed(padded)])

def C(lam, mu, nu):
    # The coefficient `g_{\lambda, \mu, \nu}`, where `\lambda, \mu, \nu`
    # are ``lam, mu, nu``, respectively. It is computed as the
    # coefficient of `s_{\nu^\vee}` in the reduction of
    # `s_\lambda s_\mu`.
    lam, mu, nu = Partition(lam), Partition(mu), Partition(nu)
    reduced_product = reduce(s[lam] * s[mu])
    target = complement(nu)
    return s(reduced_product).coefficient(target)

def check_S3symm(lam, mu, nu):
    # Compare `g_{\lambda, \mu, \nu}` (as computed by ``C``) with the
    # coefficient of `s_\omega` in the reduction of
    # `s_\lambda s_\mu s_\nu`. Prints the former and returns
    # whether the two agree.
    lam, mu, nu = Partition(lam), Partition(mu), Partition(nu)
    triple = s(reduce(s[lam] * s[mu] * s[nu]))
    clm = C(lam, mu, nu)
    print(clm)
    top_coefficient = triple.coefficient(omega)
    return clm == top_coefficient

@cached_function
def hred(lam):
    # Reduce the complete homogeneous function `h_{\text{lam}}`
    # modulo the ideal generated by the `e_i` for `i > k` and the
    # `h_{n-k+i} - a_i` for `0 < i \leq k`.
    # ``lam`` is anything accepted by ``Partition``.
    return reduce(h[Partition(lam)])

def shook(j, i):
    # Schur function of the hook partition whose first part is
    # `n-k-j+1`, followed by `i-1` parts equal to 1 (no extra parts
    # when `i \leq 1`).
    first_row = n - k - j + 1
    leg = [1] * (i - 1)
    return s[Partition([first_row] + leg)]

# testing positivity

# List of all partitions contained in the rectangle `\omega`,
# i.e. with at most `k` parts, each part at most `n-k`.
# They are listed by size `g = 0, 1, ..., k(n-k)`.
allpars = [lam for g in range(k*(n-k) + 1)
               for lam in Partitions(g, max_length=k, max_part=n-k)]

def poly_is_pos(f):
    # Return True iff none of the coefficients of the polynomial ``f``
    # (as reported by ``f.coefficients()``) is negative.
    return not any(coeff < 0 for coeff in f.coefficients())

def check_prod_pos(lam, mu):
    # For every Schur term `c \cdot s_\nu` in the reduction of
    # `s_\lambda s_\mu`, check that `(-1)^{|\lambda|+|\mu|-|\nu|} c`
    # has no negative coefficients. Stops at the first failure.
    return all(
        poly_is_pos((-1)**(sum(lam) + sum(mu) - sum(nu)) * coeff)
        for nu, coeff in reduce(s[lam]*s[mu])
    )

def check_all_pos():
    # Run ``check_prod_pos`` on pairs of partitions from ``allpars``,
    # skipping pairs with ``mu < lam`` so that a pair is not checked
    # in both orders. Prints progress; returns False (after printing
    # the offending ``mu``) on the first failure, True otherwise.
    for lam in allpars:
        print("checking lam = " + str(lam))
        partners = (mu for mu in allpars if not mu < lam)
        for mu in partners:
            if check_prod_pos(lam, mu):
                continue
            print("fail at mu = " + str(mu))
            return False
    return True

