276 lines
6.9 KiB
Python
276 lines
6.9 KiB
Python
#
|
|
# Copyright (C) 2005 Fredrik Kuivinen
|
|
#
|
|
|
|
import sys, re, os, traceback
|
|
from sets import Set
|
|
|
|
def die(*args):
    """Report a fatal error and terminate the program.

    The arguments are handed to printList, which writes them to stderr,
    and the process then exits with status 2.
    """
    printList(args, file=sys.stderr)
    raise SystemExit(2)
|
|
|
|
def printList(list, file=sys.stdout):
    """Write the items of list to file on a single line.

    Each item is converted with str() and followed by one space; the line
    is terminated with a newline. An empty list produces just the newline.
    """
    emit = file.write
    for item in list:
        emit(str(item))
        emit(' ')
    emit('\n')
|
|
|
|
import subprocess
|
|
|
|
# Debugging machinery
|
|
# -------------------
|
|
|
|
# Master switch for debug(): nothing is printed while this is false.
DEBUG = 0
|
|
# Names of the functions whose debug() calls are printed when DEBUG is set.
functionsToDebug = set()

def addDebug(func):
    """Enable debug() output for one function.

    func is either the function's name as a string or the callable itself,
    in which case its __name__ is registered.
    """
    if isinstance(func, str):
        functionsToDebug.add(func)
    else:
        # __name__ is available on every named callable (including builtins)
        # and on every Python version; func_name only exists on Python 2
        # pure-Python functions.
        functionsToDebug.add(func.__name__)
|
|
|
|
def debug(*args):
    """Print args via printList if debugging is enabled for the caller.

    Output is produced only when DEBUG is true and the name of the calling
    function has been registered in functionsToDebug.
    """
    if not DEBUG:
        return
    # The last stack entry is debug() itself; the one before it is the
    # caller. Field 2 of a stack entry is the function name.
    callerEntry = traceback.extract_stack()[-2]
    if callerEntry[2] in functionsToDebug:
        printList(args)
|
|
|
|
# Program execution
|
|
# -----------------
|
|
|
|
class ProgramError(Exception):
    """Error raised when an external program cannot be run or fails.

    Attributes:
        progStr: the command that was run.
        error:   description of the failure (OS error text or the output
                 of the program).
    """

    def __init__(self, progStr, error):
        # Pass the values on to Exception so that e.args and repr(e) carry
        # them instead of being empty.
        Exception.__init__(self, progStr, error)
        self.progStr = progStr
        self.error = error

    def __str__(self):
        # %-formatting instead of '+' so that a non-string error (e.g. a
        # None strerror) does not raise TypeError while reporting a failure.
        return '%s: %s' % (self.progStr, self.error)
|
|
|
|
# Register runProgram so that its debug() calls are printed when DEBUG is set.
addDebug('runProgram')
|
|
def runProgram(prog, input=None, returnCode=False, env=None, pipeOutput=True):
    """Run an external program and return what it printed.

    Arguments:
        prog:       the command. A string is run through the shell; any
                    other value is passed to subprocess.Popen as an
                    argument sequence.
        input:      data written to the program's stdin, or None. stdin is
                    closed afterwards in either case.
        returnCode: if true, return [output, exitCode] and do not raise on
                    a non-zero exit status.
        env:        environment passed to subprocess.Popen.
        pipeOutput: if true, stdout is captured with stderr merged into it;
                    otherwise the output is not captured and the returned
                    output is ''.

    Returns the captured output, or [output, exitCode] if returnCode is set.

    Raises ProgramError if the program cannot be started (OSError), or if
    it exits with a non-zero status while returnCode is false.
    """
    debug('runProgram prog:', str(prog), 'input:', str(input))
    useShell = isinstance(prog, str)
    if useShell:
        progStr = prog
    else:
        progStr = ' '.join(prog)

    if pipeOutput:
        stdout = subprocess.PIPE
        stderr = subprocess.STDOUT
    else:
        stdout = None
        stderr = None

    try:
        pop = subprocess.Popen(prog,
                               shell=useShell,
                               stderr=stderr,
                               stdout=stdout,
                               stdin=subprocess.PIPE,
                               env=env)
    except OSError as e:
        # strerror can be None; fall back to the generic text so that the
        # ProgramError always carries a printable message.
        message = e.strerror or str(e)
        debug('strerror:', message)
        raise ProgramError(progStr, message)

    # communicate() writes the input, closes stdin and drains stdout at the
    # same time. Writing all of stdin before reading stdout can deadlock
    # once the child fills the stdout pipe while we are still writing.
    out = pop.communicate(input)[0]
    if out is None:
        # stdout was not piped (pipeOutput is false).
        out = ''
    code = pop.returncode

    if returnCode:
        return [out, code]
    if code != 0:
        debug('error output:', out)
        debug('prog:', prog)
        raise ProgramError(progStr, out)
    return out
|
|
|
|
# Code for computing common ancestors
|
|
# -----------------------------------
|
|
|
|
# The id most recently handed out by getUniqueId(); 0 before the first call.
currentId = 0

def getUniqueId():
    """Return the next id: 1 on the first call, one more on each later call."""
    global currentId
    freshId = currentId + 1
    currentId = freshId
    return freshId
|
|
|
|
# The 'virtual' commit objects have SHAs which are integers
|
|
# \Z anchors at the very end of the string. '$' would also match just before
# a trailing newline and let '<40 hex digits>\n' pass as a valid id.
shaRE = re.compile(r'^[0-9a-f]{40}\Z')

def isSha(obj):
    """Return True if obj is a valid object id, False otherwise.

    A valid id is either a string of exactly 40 lowercase hexadecimal
    digits, or an integer >= 1 (the ids used by virtual commits).
    """
    if isinstance(obj, str):
        return shaRE.match(obj) is not None
    # Exact type check on purpose: bool is a subclass of int and must not
    # be accepted as an id.
    return type(obj) is int and obj >= 1
|
|
|
|
class Commit(object):
    """A node of the commit graph.

    Attributes:
        sha:          the commit's SHA1 string, or an integer from
                      getUniqueId() for a virtual commit.
        virtual:      True if the commit was created without a SHA1.
        parents:      list of parents, as passed to the constructor.
        children:     list of child commits; starts empty and is filled in
                      by the graph-building code.
        firstLineMsg: first line of the commit message; loaded lazily by
                      getInfo() for non-virtual commits.
        _tree:        SHA1 of the commit's tree, or None until getInfo()
                      has loaded it.
    """
    __slots__ = ['parents', 'firstLineMsg', 'children', '_tree', 'sha',
                 'virtual']

    def __init__(self, sha, parents, tree=None):
        """Create a commit.

        A false sha creates a virtual commit, which gets a unique integer
        id and must be given a valid tree. Trailing whitespace is stripped
        from both sha and tree.
        """
        self.parents = parents
        self.firstLineMsg = None
        self.children = []

        if tree:
            tree = tree.rstrip()
            assert isSha(tree)
        self._tree = tree

        if not sha:
            self.sha = getUniqueId()
            self.virtual = True
            self.firstLineMsg = 'virtual commit'
            # There is no commit object to read the tree from later, so it
            # has to be supplied up front.
            assert isSha(tree)
        else:
            self.virtual = False
            self.sha = sha.rstrip()
        assert isSha(self.sha)

    def tree(self):
        """Return the SHA1 of the commit's tree, loading it if necessary."""
        self.getInfo()
        assert self._tree is not None
        return self._tree

    def shortInfo(self):
        """Return '<sha> <first line of the commit message>'."""
        self.getInfo()
        return str(self.sha) + ' ' + self.firstLineMsg

    def __str__(self):
        return self.shortInfo()

    def getInfo(self):
        """Load the tree and first message line from the commit object.

        Does nothing for virtual commits or when the information has
        already been loaded.
        """
        if self.virtual or self.firstLineMsg is not None:
            return

        info = runProgram(['git-cat-file', 'commit', self.sha])
        inMessage = False
        for line in info.split('\n'):
            if inMessage:
                # First line after the blank separator line.
                self.firstLineMsg = line
                break
            # Match the prefix including its space so that it agrees with
            # the five characters sliced off below.
            if line.startswith('tree '):
                self._tree = line[5:].rstrip()
            elif line == '':
                inMessage = True
|
|
|
|
class Graph(object):
    """A commit graph.

    Attributes:
        commits: list of the nodes in the order they were added.
        shaMap:  dictionary mapping a node's sha to the node.
    """

    def __init__(self):
        self.commits = []
        self.shaMap = {}

    def addNode(self, node):
        """Add a Commit to the graph and register it as a child of each of
        its parents. Returns the node."""
        assert isinstance(node, Commit)
        self.shaMap[node.sha] = node
        self.commits.append(node)
        for parent in node.parents:
            parent.children.append(node)
        return node

    def reachableNodes(self, n1, n2):
        """Return a dictionary whose keys are n1, n2 and every node that can
        be reached from them by following parents; all values are True."""
        res = {}
        # Explicit stack with a visited check: plain recursion revisits
        # shared ancestors once per path and exceeds the recursion limit on
        # long histories.
        pending = [n1, n2]
        while pending:
            node = pending.pop()
            if node in res:
                continue
            res[node] = True
            pending.extend(node.parents)
        return res

    def fixParents(self, node):
        """Replace, in place, each sha in node.parents with the node that
        shaMap holds for it. Raises KeyError for an unknown sha."""
        node.parents[:] = [self.shaMap[sha] for sha in node.parents]
|
|
|
|
# addDebug('buildGraph')
|
|
def buildGraph(heads):
    """Build the Graph of all commits listed by git-rev-list for heads.

    heads is a list of ids; each one must satisfy isSha. Every line of the
    'git-rev-list --parents' output is split on spaces: the first field
    becomes the commit's sha, the remaining fields its parents.
    """
    debug('buildGraph heads:', heads)
    for head in heads:
        assert(isSha(head))

    graph = Graph()

    revList = runProgram(['git-rev-list', '--parents'] + heads)
    nonEmptyLines = [line for line in revList.split('\n') if line != '']
    for line in nonEmptyLines:
        fields = line.split(' ')
        commitSha, parentShas = fields[0], fields[1:]

        # HACK: 'parents' temporarily holds a list of SHA1 strings. They
        # are replaced by the proper Commit objects in the next loop, once
        # every commit has been registered in shaMap.
        commit = Commit(commitSha, parentShas)

        graph.commits.append(commit)
        graph.shaMap[commit.sha] = commit

    for commit in graph.commits:
        graph.fixParents(commit)

    # With the parents resolved, record each commit as a child of them.
    for commit in graph.commits:
        for parent in commit.parents:
            parent.children.append(commit)
    return graph
|
|
|
|
# Write the empty tree to the object database and return its SHA1
|
|
def writeEmptyTree():
    """Write the empty tree to the object database and return its SHA1.

    git-write-tree is run with GIT_INDEX_FILE pointing at a temporary index
    file inside the git directory ($GIT_DIR, default '.git'). Any stale copy
    of that file is removed first, presumably so that git-write-tree sees an
    index without entries. The file is removed again afterwards.

    Raises ProgramError (from runProgram) if git-write-tree fails.
    """
    tmpIndex = os.environ.get('GIT_DIR', '.git') + '/merge-tmp-index'

    def delTmpIndex():
        try:
            os.unlink(tmpIndex)
        except OSError:
            # Best effort: the file usually does not exist.
            pass

    delTmpIndex()
    newEnv = os.environ.copy()
    newEnv['GIT_INDEX_FILE'] = tmpIndex
    try:
        return runProgram(['git-write-tree'], env=newEnv).rstrip()
    finally:
        # Remove the temporary index even when git-write-tree fails.
        delTmpIndex()
|
|
|
|
def addCommonRoot(graph):
    """Add a virtual commit to graph as the parent of every root commit.

    The roots are the commits without parents. The new commit has no
    parents itself, uses the tree returned by writeEmptyTree(), and gets
    the former roots as its children. Returns the new commit.
    """
    roots = [commit for commit in graph.commits if len(commit.parents) == 0]

    superRoot = Commit(sha=None, parents=[], tree=writeEmptyTree())
    graph.addNode(superRoot)
    for root in roots:
        root.parents = [superRoot]
    superRoot.children = roots
    return superRoot
|
|
|
|
def getCommonAncestors(graph, commit1, commit2):
    '''Find the common ancestors for commit1 and commit2.

    Returns a list of the commits that are reachable (via parents) from
    both commit1 and commit2 and that have no child which is itself
    reachable from both. The order of the list is unspecified.

    If the two commits have nothing in common, a virtual root commit is
    added to graph with addCommonRoot() and returned as the only element.
    '''
    assert isinstance(commit1, Commit) and isinstance(commit2, Commit)

    def ancestors(start):
        # Iterative traversal; returns start plus everything reachable
        # from it by following parents.
        seen = set()
        stack = [start]
        while stack:
            el = stack.pop()
            seen.add(el)
            for p in el.parents:
                if p not in seen:
                    stack.append(p)
        return seen

    shared = ancestors(commit1) & ancestors(commit2)

    if not shared:
        shared = [addCommonRoot(graph)]

    # The elements of shared are unique, so the result needs no
    # de-duplication.
    return [s for s in shared
            if not any(c in shared for c in s.children)]
|