import operator

import pari
from pari import GEN, set_prec

# Exception raised by this module.  A class rather than a string: string
# exceptions are deprecated (PEP 352) and cannot be raised at all from
# Python 2.6 on, while `raise PariError, msg` and `except PariError` keep
# working unchanged with a class.
class PariError(Exception) :
	"""Error raised by this module (e.g. for an unsupported gen initialiser)."""
	pass

# Set to a true value to print debugging traces from gen.__cmp__ and
# PariFun.__call__.
pari_debug = 0

# Module constants.
gzero = GEN(0) # The Python integer 0 converted to a GEN.
# Type codes; presumably PARI's internal typ() numbers -- TODO confirm
# against the PARI headers for the version in use.
INTEGER = 1
REAL = 2
INTEGERMOD = 3
RATNUM = 4
COMPLEX = 6
PADIC = 7
QUADRATIC = 8
POLYMOD = 9
POLYNOMIAL = 10
SERIES = 11
RATFUN = 13
QUADFORM = 16
# NOTE(review): in PARI's numbering type 17 (t_VEC) is a row vector and
# type 18 (t_COL) a column (transposed) vector; these were previously
# documented the other way round -- confirm.
VECTOR = 17
TRANSVECTOR = 18
MATRIX = 19

# Class gen = wrapper class around builtin type GEN.

def gen2GEN (x) :
	"""Unwrap gen instances into raw GEN values.

	A list or tuple is converted element by element (recursively) and
	returned as a list; a gen instance is replaced by its .value; any other
	object is returned unchanged.
	"""
	if type(x) in (type([]), type((0,))) :
		return [gen2GEN(item) for item in x]
	# Test the class directly instead of raising and swallowing an
	# exception; objects without __class__ are simply not gens.
	if getattr(x, '__class__', None) == gen :
		return x.value
	return x
		
class gen :
	def __init__(self, value = None) :
		if type(value) == type(self) and value.__class__ == gen :
			self.value = value.value
		elif value == None :
			self.value = None
		elif type(value) in (type(0), type(0.0), type(gzero), type('')) :
			self.value = GEN(value)
		elif type(value) in (type([]), type((0,)) ) :
			self.value = GEN(gen2GEN(value))
		else : 
			raise PariError, "wrong initialiser type"
	def __cmp__(self, other) :
		if pari_debug :
			print 'comparing', self, ':', type(self), 'and', other, ':', type(other)
		if type(other) != type(self) or other.__class__ != gen :
			other = gen(other)
		res = cmp(self.value, other.value)
		if pari_debug :
			print 'res =', res
		return res
# Apparently not used.
#	def __nonzero__(self) :
#		print 'nonzero'
#		res = cmp(self.value, gzero) 
#		print res
#		print res != 0
#		return res != 0
	def __getattr__(self, name) :
		return eval('self.value.'+name)
	def __str__(self) : 
		return str(self.value)
	def __repr__(self) : 
		return '<gen instance, value = ' + str(self.value) + ' >'
	# Useful functions.
	def bin_op(self, other, op) :
		if type(other) != type(self) or other.__class__ != gen :
			other = gen(other)
		res = gen()
		if op != 'pow' :
			res.value = eval('self.value' + op + 'other.value')
		else :
			res.value = eval('pow(self.value, other.value)')
		return res
	def rev_bin_op(self, other, op) :
		if type(other) != type(self) or other.__class__ != gen :
			other = gen(other)
		res = gen()
		if op != 'pow' :
			res.value = eval('other.value' + op + 'self.value')
		else :
			res.value = eval('pow(other.value, self.value)')
		return res
	# Numeric operations.
	def __add__(self, other) :
		return self.bin_op(other, '+')
	def __radd__(self, other) :
		return self.bin_op(other, '+')
	def __sub__(self, other) :
		return self.bin_op(other, '-')
	def __rsub__(self, other) :
		return self.rev_bin_op(other, '-')
	def __mul__(self, other) :
		return self.bin_op(other, '*')
	def __rmul__(self, other) :
		return self.bin_op(other, '*')
	def __div__(self, other) :
		return self.bin_op(other, '/')
	def __rdiv__(self, other) :
		return self.rev_bin_op(other, '/')
	def __mod__(self, other) :
		return self.bin_op(other, '%')
	def __rmod__(self, other) :
		return self.rev_bin_op(other, '%')
	def __pow__(self, other) :
		return self.bin_op(other, 'pow')
	def __rpow__(self, other) :
		return self.rev_bin_op(other, 'pow')
	def __neg__(self) :
		return gen(-self.value)
	def __pos__(self) :
		return self
	def __abs__(self) :
		return gen(abs(self.value))
	def __nonzero__(self) :
		return self.value != gzero
	def __int__(self) :
		return int(self.value)
	def __float__(self) :
		return float(self.value)
	def __float__(self) :
		return long(self.value)
	# Transforms other into a gen.
	#def __coerce__(self, other) :	
	#	return (self, gen(other))
	# Sequence operations.
	# GEN compound object have a range of [1..lg(x)-1] so we are in
	# trouble...
	def __len__(self) :
		return len(self.value)
	def __getitem__(self, i) :
		return gen(self.value[i])
	def __setitem__(self, i, value) :
		self.value[i] = gen(value).value
	# There is no __list__ method. Convert self to list.
	# Maybe the best way to solve iteration problem (for x in X.list() :...)
	def list(self) :
		res = []
		for i in range(1, self.value.lg) :
			res.append(gen(self.value[i]))
		return res

def equal(x, y) :
	"""Return pari.equal() applied to the GEN values of gen(x) and gen(y)."""
	left, right = gen(x), gen(y)
	return pari.equal(left.value, right.value)

def matrix(x) :
	"""Convert x with gen(), pass its GEN to pari.matrix(), and wrap the result.

	Presumably builds a PARI matrix from x -- confirm against the pari
	extension.
	"""
	source = gen(x).value
	return gen(pari.matrix(source))
	
# A safer lisexpr: remove every space character from the string before
# handing it to GP.
def lisexpr(s) :
	"""Return pari.lisexpr(s) with all ' ' characters removed from s first.

	Only spaces are removed (tabs and newlines are kept), and spaces inside
	GP string literals are removed as well.
	"""
	return pari.lisexpr(s.replace(' ', ''))

# Evaluate a GP expression string, taking variable values from the __main__
# name-space.  (BOGUS: the identifier scan is only approximate.)
import regex
# Group 1: an identifier; group 2: the text following it.
prog1 = regex.compile("\([a-zA-Z][a-zA-Z0-9]*\)\(.*\)")
# Optional blanks followed by '(': the identifier is a function call.
prog2 = regex.compile("[ \t]*(")

def peval(str) :
	"""Evaluate the GP expression str and return the result as a gen.

	Every identifier in str that is not followed by '(' and that names a
	variable in __main__ is converted with gen() and installed as a PARI
	variable for the duration of the evaluation, then removed again.
	Identifiers not defined in __main__ are left for GP to interpret.

	The scan is approximate: it also picks up letters inside numeric
	literals (the 'e5' of '1e5') and inside GP string literals.
	"""
	import __main__
	namespace = __main__.__dict__
	# Collect the identifiers that are not function calls.
	idents = {}
	rest = str
	while prog1.search(rest) != -1 :
		ident = prog1.group(1)
		rest = prog1.group(2)
		if prog2.match(rest) == -1 :
			idents[ident] = None
	installed = []
	try :
		for ident in idents.keys() :
			# Names unknown to Python are passed through to GP unchanged
			# (GP constants, free polynomial variables, ...).
			if not namespace.has_key(ident) :
				continue
			pari.install_var(ident, gen(namespace[ident]).value)
			installed.append(ident)
		res = gen(lisexpr(str))
	finally :
		# Remove exactly the variables installed above, even on error.
		for ident in installed :
			pari.kill_var(ident)
	return res

# PARI function objects. (Remember that in PARI, function are not 
# really first class objects).
# One should avoid to have variables name 'python_temp<int>' in
# the PARI machine.
class PariFun :
	def __init__(self, name ) :
		self.name = name
	def __call__(self, *args) :
		if pari_debug : print 'calling :', self.name, args
		i = 1
		command = self.name + '('
		for arg in args :
			pari.install_var('pythontemp'+`i`, gen(arg).value)
			command = command + 'pythontemp'+`i` + ','
			i = i + 1
		command = command[:-1] + ')'
		if pari_debug : print command
		try : 
			res = gen(lisexpr(command))
		finally :
			for j in range(1, i) :
				pari.kill_var('pythontemp'+`j`)
		return res

# Read PARI functions table.
import Pari # Import myself to play with my own name-space.

# Publish every GP function as a module attribute holding a PariFun proxy.
for f in pari.gp_functions() :
	# Skip names that would clash with existing Python definitions: the
	# builtin eval, and the matrix() wrapper defined in this module.
	if f in ('eval', 'matrix') :
		continue
	# GP's 'type' is published as 'ptype' so the builtin type() used by
	# this module is not shadowed.
	if f == 'type' :
		name = 'ptype'
	else :
		name = f
	setattr(Pari, name, PariFun(f))

###############################################################################
# Just in case someone doesn't have regsub.py

import regex

def my_gsub(pat, repl, str):
	"""Replace every match of the regex-module pattern pat in str by repl.

	repl may contain backslash-digit group references, expanded by expand().
	An empty match directly where the previous match ended is not used;
	the search is retried one character further on, and the scan stops if
	that fails or the end of str was reached.
	"""
	pattern = regex.compile(pat)
	pieces = []
	pos = 0
	at_start = 1
	while pattern.search(str, pos) >= 0:
		groups = pattern.regs
		begin, end = groups[0]
		if begin == end == pos and not at_start:
			# Empty match at the end of the previous one: retry one
			# character further on, or stop if nothing is left.
			if pos >= len(str) or pattern.search(str, pos + 1) < 0:
				break
			groups = pattern.regs
			begin, end = groups[0]
		pieces.append(str[pos:begin])
		pieces.append(expand(repl, groups, str))
		pos = end
		at_start = 0
	pieces.append(str[pos:])
	return ''.join(pieces)

def expand(repl, regs, str):
	"""Return repl with its backslash escapes expanded against a match in str.

	regs is a sequence of (start, end) index pairs, one per group (as read
	from a compiled regex object's regs attribute by my_gsub).  In repl:
	a backslash followed by a digit d is replaced by str[start:end] for
	group d; a doubled backslash becomes a single backslash; any other
	backslash pair is kept unchanged; a trailing lone backslash is copied.
	"""
	if '\\' not in repl:
		return repl
	pieces = []
	n = len(repl)
	i = 0
	while i < n:
		c = repl[i]
		i = i + 1
		if c != '\\' or i >= n:
			pieces.append(c)
			continue
		c = repl[i]
		i = i + 1
		if '0' <= c <= '9':
			# int() of the single digit; no need to eval() it.
			a, b = regs[int(c)]
			pieces.append(str[a:b])
		elif c == '\\':
			pieces.append(c)
		else:
			pieces.append('\\' + c)
	return ''.join(pieces)

