Source code for opt_einsum.parser

#!/usr/bin/env python
# coding: utf-8
"""
A functionally equivalent parser of the numpy.einsum input parser
"""

import operator

import numpy as np


# The 52 ASCII letters treated as valid numpy einsum subscript labels.
# Index ``i`` in 0..51 maps to ``einsum_symbols_base[i]`` in ``get_symbol``.
einsum_symbols_base = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'


def is_valid_einsum_char(x):
    """Return whether ``x`` may appear in a subscript string passed to numpy einsum.

    Letters from ``einsum_symbols_base`` are accepted, as are the structural
    characters ``,``, ``-``, ``>`` and ``.``. Intended to be called with a
    single character.
    """
    if x in einsum_symbols_base:
        return True
    return x in ',->.'


def has_valid_einsum_chars_only(einsum_str):
    """Return ``True`` when every character of ``einsum_str`` passes
    ``is_valid_einsum_char`` (vacuously ``True`` for an empty string).
    """
    return all(is_valid_einsum_char(char) for char in einsum_str)


def get_symbol(i):
    """Get the symbol corresponding to int ``i`` - runs through the usual 52
    letters before resorting to unicode characters, starting at ``chr(192)``.

    Code points in the UTF-16 surrogate range (U+D800 - U+DFFF) are skipped:
    lone surrogates cannot be encoded (e.g. to UTF-8), so returning them would
    make an expression containing them impossible to print or serialize.

    Examples
    --------
    >>> get_symbol(2)
    'c'
    >>> get_symbol(200)
    'Ŕ'
    >>> get_symbol(20000)
    '京'
    """
    if i < 52:
        return einsum_symbols_base[i]
    code = i + 140
    # Jump over the 2048 surrogate code points; shifting everything at or
    # above U+D800 keeps the int -> symbol mapping injective.
    if code >= 0xD800:
        code += 0x800
    return chr(code)
def gen_unused_symbols(used, n):
    """Generate ``n`` symbols that are not already in ``used``.

    Candidates are taken in order ``get_symbol(0), get_symbol(1), ...`` and any
    candidate contained in ``used`` is skipped.
    """
    i = cnt = 0
    while cnt < n:
        s = get_symbol(i)
        i += 1
        if s in used:
            continue
        yield s
        cnt += 1


def convert_to_valid_einsum_chars(einsum_str):
    """Convert the str ``einsum_str`` to contain only the alphabetic
    characters valid for numpy einsum.

    Every character rejected by ``is_valid_einsum_char`` is replaced by a
    symbol not already present among the valid characters. Invalid characters
    are assigned replacements in sorted order so the result is deterministic
    (iteration order of a ``set`` of strings varies between interpreter runs).

    NOTE: once the unused letters run out, ``gen_unused_symbols`` yields
    non-ASCII symbols, which numpy einsum still rejects.
    """
    # partition into valid and invalid sets
    valid, invalid = set(), set()
    for x in einsum_str:
        (valid if is_valid_einsum_char(x) else invalid).add(x)

    # get replacements for invalid chars that are not already used
    available = gen_unused_symbols(valid, len(invalid))

    # map invalid to available (in a stable order) and replace in the inputs
    replacer = dict(zip(sorted(invalid), available))
    return "".join(replacer.get(x, x) for x in einsum_str)


def find_output_str(subscripts):
    """Find the output string for the inputs ``subscripts``.

    The output (implicit mode) is every index that occurs exactly once across
    the comma-separated input terms, in sorted order.
    """
    tmp_subscripts = subscripts.replace(",", "")
    return "".join(s for s in sorted(set(tmp_subscripts)) if tmp_subscripts.count(s) == 1)


def find_output_shape(inputs, shapes, output):
    """Find the output shape for given inputs, shapes and output string, taking
    into account broadcasting.

    For each output index the largest dimension among the inputs containing
    it is taken, so size-1 (broadcast) dimensions give way to the real size.
    """
    return tuple(
        max(shape[loc] for shape, loc in zip(shapes, [x.find(c) for x in inputs]) if loc >= 0)
        for c in output
    )


def possibly_convert_to_numpy(x):
    """Convert things without a 'shape' to ndarrays, but leave everything else.
    """
    if not hasattr(x, 'shape'):
        return np.asanyarray(x)
    else:
        return x


def _convert_subscript_list(sub):
    """Translate one interleaved-format subscript list (ints / Ellipsis) to a string."""
    converted = ""
    for s in sub:
        if s is Ellipsis:
            converted += "..."
            continue
        # operator.index accepts int and integer-like objects (e.g. numpy
        # integer scalars) but rejects floats, strings, etc.
        try:
            idx = operator.index(s)
        except TypeError as err:
            raise TypeError("For this input type lists must contain either int or Ellipsis") from err
        # get_symbol would silently alias negatives onto other letters.
        if idx < 0:
            raise ValueError("Subscript indices must be non-negative, got %d" % idx)
        converted += get_symbol(idx)
    return converted


def parse_einsum_input(operands):
    """
    A reproduction of einsum c side einsum parsing in python.

    Parameters
    ----------
    operands : sequence
        Either ``(subscripts, *arrays)`` with ``subscripts`` an einsum string,
        or the interleaved form ``(array0, sub0, array1, sub1, ..., [out])``
        where each ``sub`` is a list of non-negative ints and/or ``Ellipsis``.

    Returns
    -------
    input_strings : str
        Parsed input strings
    output_string : str
        Parsed output string
    operands : list of array_like
        The operands to use in the numpy contraction

    Raises
    ------
    ValueError
        For empty input, malformed '->' or ellipses, output characters missing
        from the inputs, negative subscript ints, or a mismatch between the
        number of terms and operands.
    TypeError
        If an interleaved subscript list contains something other than an
        integer or ``Ellipsis``.

    Examples
    --------
    The operand list is simplified to reduce printing:

    >>> a = np.random.rand(4, 4)
    >>> b = np.random.rand(4, 4, 4)
    >>> parse_einsum_input(('...a,...a->...', a, b))
    ('da,cda', 'cd', [a, b])

    >>> parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0]))
    ('da,cda', 'cd', [a, b])
    """

    if len(operands) == 0:
        raise ValueError("No input operands")

    if isinstance(operands[0], str):
        subscripts = operands[0].replace(" ", "")
        operands = [possibly_convert_to_numpy(x) for x in operands[1:]]

    else:
        # Interleaved form: (op0, sub0, op1, sub1, ..., [output_sub])
        tmp_operands = list(operands)
        n_pairs = len(tmp_operands) // 2
        operand_list = tmp_operands[0:2 * n_pairs:2]
        subscript_list = tmp_operands[1:2 * n_pairs:2]
        # A trailing unpaired element is the output subscript list.
        output_list = tmp_operands[-1] if len(tmp_operands) % 2 else None
        operands = [possibly_convert_to_numpy(x) for x in operand_list]

        subscripts = ",".join(_convert_subscript_list(sub) for sub in subscript_list)
        if output_list is not None:
            subscripts += "->" + _convert_subscript_list(output_list)

    # Check for proper "->"
    if ("-" in subscripts) or (">" in subscripts):
        invalid = (subscripts.count("-") > 1) or (subscripts.count(">") > 1)
        if invalid or (subscripts.count("->") != 1):
            raise ValueError("Subscripts can only contain one '->'.")

    # Parse ellipses
    if "." in subscripts:
        # Do we have an output to account for?
        if "->" in subscripts:
            input_tmp, output_sub = subscripts.split("->")
            split_subscripts = input_tmp.split(",")
            out_sub = True
        else:
            split_subscripts = subscripts.split(',')
            out_sub = False

        # Validate the term count first: the loop below indexes operands[num]
        # and would otherwise fail with a bare IndexError.
        if len(split_subscripts) != len(operands):
            raise ValueError("Number of einsum subscripts must be equal to the "
                             "number of operands.")

        used = subscripts.replace(".", "").replace(",", "").replace("->", "")
        ellipse_inds = "".join(gen_unused_symbols(used, max(len(x.shape) for x in operands)))
        longest = 0

        for num, sub in enumerate(split_subscripts):
            if "." in sub:
                if (sub.count(".") != 3) or (sub.count("...") != 1):
                    raise ValueError("Invalid Ellipses.")

                # Take into account numerical values
                if operands[num].shape == ():
                    ellipse_count = 0
                else:
                    ellipse_count = max(len(operands[num].shape), 1) - (len(sub) - 3)

                if ellipse_count > longest:
                    longest = ellipse_count

                if ellipse_count < 0:
                    raise ValueError("Ellipses lengths do not match.")
                elif ellipse_count == 0:
                    split_subscripts[num] = sub.replace('...', '')
                else:
                    split_subscripts[num] = sub.replace('...', ellipse_inds[-ellipse_count:])

        subscripts = ",".join(split_subscripts)

        # Figure out output ellipses (explicit branch: ellipse_inds[-0:] would
        # be the whole string, not an empty one)
        if longest == 0:
            out_ellipse = ""
        else:
            out_ellipse = ellipse_inds[-longest:]

        if out_sub:
            subscripts += "->" + output_sub.replace("...", out_ellipse)
        else:
            # Special care for outputless ellipses
            output_subscript = find_output_str(subscripts)
            normal_inds = ''.join(sorted(set(output_subscript) - set(out_ellipse)))

            subscripts += "->" + out_ellipse + normal_inds

    # Build output string if does not exist
    if "->" in subscripts:
        input_subscripts, output_subscript = subscripts.split("->")
    else:
        input_subscripts, output_subscript = subscripts, find_output_str(subscripts)

    # Make sure output subscripts are in the input
    for char in output_subscript:
        if char not in input_subscripts:
            raise ValueError("Output character %s did not appear in the input" % char)

    # Make sure number operands is equivalent to the number of terms
    if len(input_subscripts.split(',')) != len(operands):
        raise ValueError("Number of einsum subscripts must be equal to the "
                         "number of operands.")

    return input_subscripts, output_subscript, operands