1
2
3
4
5
6
7
8
9
10
11 """ a set of functions for interacting with databases
12
13 When possible, it's probably preferable to use a _DbConnection.DbConnect_ object
14
15 """
16 from __future__ import print_function
17
18 import sys
19
20 from rdkit.Dbase import DbInfo
21 from rdkit.Dbase import DbModule
22 from rdkit.Dbase.DbResultSet import DbResultSet, RandomAccessDbResultSet
23 from rdkit.six import string_types, StringIO
24 from rdkit.six.moves import xrange
25
26
28 """ Given a list fromL, returns an iterator of the elements specified using their
29 indices in the list what """
30 return map(lambda x, y=fromL: y[x], what)
31
32
def GetColumns(dBase, table, fieldString, user='sysdba', password='masterkey', join='', cn=None):
    """ gets a set of data from a table

    **Arguments**

      - dBase: database name

      - table: table name

      - fieldString: a string with the names of the fields to be extracted,
        this should be a comma delimited list

      - user and password: credentials passed to DbModule.connect()
        (only used when cn is not provided)

      - join: a join clause (the leading verb 'join' may be omitted)

      - cn: (optional) an existing connection to use instead of opening
        a new one

    **Returns**

      - the rows returned by the cursor's fetchall()

    """
    if not cn:
        cn = DbModule.connect(dBase, user, password)
    cursor = cn.cursor()
    cmd = 'select %s from %s' % (fieldString, table)
    if join:
        # SQL keywords are case-insensitive: 'JOIN foo ...' must not get a
        # second 'join' prepended.
        if not join.strip().lower().startswith('join'):
            join = 'join %s' % (join)
        cmd += ' ' + join
    cursor.execute(cmd)
    return cursor.fetchall()
65
66
def GetData(dBase, table, fieldString='*', whereString='', user='sysdba', password='masterkey',
            removeDups=-1, join='', forceList=0, transform=None, randomAccess=1, extras=None,
            cn=None):
    """ a more flexible method to get a set of data from a table

    **Arguments**

      - dBase: database name (only used when cn is not provided)

      - table: table name

      - fieldString: a string with the names of the fields to be extracted,
        this should be a comma delimited list

      - whereString: the SQL where clause to be used with the DB query
        (the leading keyword 'where' may be omitted)

      - user and password: credentials passed to DbModule.connect()
        (only used when cn is not provided)

      - removeDups indicates the column which should be used to screen
        out duplicates.  Only the first appearance of a duplicate will
        be left in the dataset.  Setting it (>= 0) implies forceList.

      - join: a join clause (the leading verb 'join' may be omitted)

      - forceList: if set, the query is run immediately and a list of rows
        is returned instead of a result-set object

      - transform: passed through to the result-set object; not compatible
        with forceList

      - randomAccess: selects RandomAccessDbResultSet (true) or
        DbResultSet (false); must be true when forceList is set

      - extras: (optional) passed as the second argument to execute()

      - cn: (optional) an existing connection to use instead of opening
        a new one

    **Returns**

      - a list of rows when forceList (or removeDups) is set, otherwise a
        result-set object.  If the forced query raises, the error is
        written to stderr and None is returned.

    **Raises**

      - ValueError for incompatible argument combinations

    """
    if removeDups > -1:
        forceList = True
    # Validate before touching the database so that an invalid combination
    # (including removeDups together with transform) never costs a query.
    if forceList:
        if transform is not None:
            raise ValueError('forceList and transform arguments are not compatible')
        if not randomAccess:
            raise ValueError('when forceList is set, randomAccess must also be used')

    if not cn:
        cn = DbModule.connect(dBase, user, password)
    c = cn.cursor()
    cmd = 'select %s from %s' % (fieldString, table)
    if join:
        # keyword checks are case-insensitive, as SQL keywords are
        if not join.strip().lower().startswith('join'):
            join = 'join %s' % (join)
        cmd += ' ' + join
    if whereString:
        if not whereString.strip().lower().startswith('where'):
            whereString = 'where %s' % (whereString)
        cmd += ' ' + whereString

    if not forceList:
        klass = RandomAccessDbResultSet if randomAccess else DbResultSet
        return klass(c, cn, cmd, removeDups=removeDups, transform=transform, extras=extras)

    try:
        if not extras:
            c.execute(cmd)
        else:
            c.execute(cmd, extras)
    except Exception:
        # deliberate best-effort: report the failure and hand back None
        sys.stderr.write('the command "%s" generated errors:\n' % (cmd))
        import traceback
        traceback.print_exc()
        return None

    data = c.fetchall()
    if removeDups >= 0:
        # Single pass keeping the first row seen for each key; this preserves
        # row order and does not require fetchall() to return a mutable list.
        seen = set()
        unique = []
        for entry in data:
            key = entry[removeDups]
            if key not in seen:
                seen.add(key)
                unique.append(entry)
        data = unique
    return data
145
146
def DatabaseToText(dBase, table, fields='*', join='', where='', user='sysdba', password='masterkey',
                   delim=',', cn=None):
    """ Pulls the contents of a database and makes a delimited text file from them

    **Arguments**
      - dBase: the name of the DB file to be used

      - table: the name of the table to query

      - fields: the fields to select with the SQL query

      - join: the join clause of the SQL query
        (e.g. 'join foo on foo.bar=base.bar')

      - where: the where clause of the SQL query
        (e.g. 'where foo = 2' or 'where bar > 17.6')

      - user: the username for DB access

      - password: the password to be used for DB access

      - delim: the delimiter placed between fields on each line

      - cn: (optional) an existing connection to use instead of opening
        a new one

    **Returns**

      - the delimited data (as text); the first line holds the column names

    **Notes**

      - columns whose type code is in DbInfo.sqlBinTypes are left out

      - values are converted with str() and are not quoted or escaped

    """
    # A where clause has to *start* with the keyword; a substring test would
    # skip the prefix for clauses that merely contain "where" (subqueries,
    # values such as 'nowhere') and produce invalid SQL.
    if where and not where.strip().lower().startswith('where'):
        where = 'where %s' % (where)
    # NOTE(review): the join test is intentionally left as a substring search;
    # clauses such as 'left join foo on ...' depend on it.
    if len(join) and join.strip().find('join') == -1:
        join = 'join %s' % (join)
    sqlCommand = 'select %s from %s %s %s' % (fields, table, join, where)
    if not cn:
        cn = DbModule.connect(dBase, user, password)
    c = cn.cursor()
    c.execute(sqlCommand)

    # pick the columns to export, skipping the binary ones
    headers = []
    colsToTake = []
    for i, item in enumerate(c.description):
        if item[1] not in DbInfo.sqlBinTypes:
            colsToTake.append(i)
            headers.append(item[0])

    lines = [delim.join(headers)]
    for res in c.fetchall():
        lines.append(delim.join(map(str, _take(res, colsToTake))))

    return '\n'.join(lines)
202
203
def TypeFinder(data, nRows, nCols, nullMarker=None):
    """

    finds the types of the columns in _data_

    if nullMarker is not None, elements of the data table which are
    equal to nullMarker will not count towards setting the type of
    their columns.

    **Arguments**

      - data: a sequence of rows, indexable as data[row][col]

      - nRows, nCols: the number of rows and columns to examine

      - nullMarker: (optional) text value to be treated as "no data"

    **Returns**

      - a list with one [type, width] entry per column.  type is float,
        int or str (-1 if the column only contains None); width is the
        longest string representation seen (at least 1).

    **Notes**

      - any real text value makes the column a str column, regardless of
        the order in which text and numbers appear

      - floats with integral values (e.g. 2.0) count as ints

    """
    priorities = {float: 3, int: 2, str: 1, -1: -1}
    res = [None] * nCols
    for col in range(nCols):
        typeHere = [-1, 1]
        sawNullMarker = False
        for row in range(nRows):
            d = data[row][col]
            if d is None:
                continue
            locType = type(d)
            if locType is not float and locType is not int:
                # everything that is not a plain number is handled as text
                try:
                    d = str(d)
                except UnicodeError as msg:
                    print('cannot convert text from row %d col %d to a string' % (row + 2, col))
                    print('\t>%s' % (repr(d)))
                    raise UnicodeError(msg)
                if nullMarker is not None and d == nullMarker:
                    sawNullMarker = True
                else:
                    typeHere = [str, max(len(d), typeHere[1])]
                continue

            typeHere[1] = max(typeHere[1], len(str(d)))
            try:
                isIntegral = float(int(d)) == d
            except (OverflowError, ValueError):
                # inf (OverflowError) and nan (ValueError) have no int form
                locType = float
            else:
                if isIntegral:
                    locType = int
            # once a column holds text it stays text; numbers only compete
            # with other numbers
            if typeHere[0] is not str and priorities[locType] > priorities[typeHere[0]]:
                typeHere[0] = locType
        if typeHere[0] == -1 and sawNullMarker:
            # a column holding nothing but null markers is reported as text
            typeHere[0] = str
        res[col] = typeHere
    return res
250
251
253 """ *For Internal Use*
254
255 removes illegal characters from column headings
256 and truncates those which are too long.
257
258 """
259 for i in xrange(len(colHeadings)):
260
261 colHeadings[i] = colHeadings[i].strip()
262 colHeadings[i] = colHeadings[i].replace(' ', '_')
263 colHeadings[i] = colHeadings[i].replace('-', '_')
264 colHeadings[i] = colHeadings[i].replace('.', '_')
265
266 if len(colHeadings[i]) > maxColLabelLen:
267
268 newHead = colHeadings[i].replace('_', '')
269 newHead = newHead[:maxColLabelLen]
270 print('\tHeading %s too long, changed to %s' % (colHeadings[i], newHead))
271 colHeadings[i] = newHead
272 return colHeadings
273
274
def GetTypeStrings(colHeadings, colTypes, keyCol=None):
    """ returns a list of SQL type strings

    **Arguments**

      - colHeadings: the column names

      - colTypes: one (type, width) entry per column, as produced by
        TypeFinder()

      - keyCol: (optional) name of the column to mark as the primary key

    **Returns**

      - a list of '<name> <sql type>' strings, one per entry in colTypes

    """
    # NOTE(review): the def line was absent from the recovered source; the
    # signature is reconstructed from GetTypeStrings(colHeadings, colTypes,
    # keyCol=keyCol).  The default for keyCol is assumed to be None.
    typeStrs = []
    for i, typ in enumerate(colTypes):
        heading = colHeadings[i]
        if typ[0] == float:
            typeStr = '%s double precision' % heading
        elif typ[0] == int:
            typeStr = '%s integer' % heading
        else:
            # anything that is not numeric is stored as text of the given width
            typeStr = '%s varchar(%d)' % (heading, typ[1])
        if heading == keyCol:
            typeStr = '%s not null primary key' % (typeStr)
        typeStrs.append(typeStr)
    return typeStrs
290
291
293 try:
294 conn.cursor().executemany(sqlStr, block)
295 except Exception:
296 res = 0
297 conn.commit()
298 for row in block:
299 try:
300 conn.cursor().execute(sqlStr, tuple(row))
301 res += 1
302 except Exception:
303 if not silent:
304 import traceback
305 traceback.print_exc()
306 print('insert failed:', sqlStr)
307 print('\t', repr(row))
308 else:
309 conn.commit()
310 else:
311 res = len(block)
312 return res
313
314
def _AddDataToDb(dBase, table, user, password, colDefs, colTypes, data, nullMarker=None,
                 blockSize=100, cn=None):
    """ *For Internal Use*

    (drops and) creates a table and then inserts the values

    **Arguments**

      - dBase, user, password: used to open a connection when cn is not
        provided

      - table: the name of the table to (re)create

      - colDefs: the column definitions for the create table statement

      - colTypes: one (type, width) entry per column; the type decides how
        each value is converted before insertion

      - data: a sequence of rows

      - nullMarker: (optional) values equal to this are inserted as NULL

      - blockSize: the number of rows sent to the DB per insert block

      - cn: (optional) an existing connection to use

    **Notes**

      - if the table cannot be created the error is printed and the
        function returns without inserting anything

      - with empty data the table is created and left empty

    """
    if not cn:
        cn = DbModule.connect(dBase, user, password)
    c = cn.cursor()
    try:
        c.execute('drop table %s' % (table))
    except Exception:
        # best effort: the table may simply not exist yet
        print('cannot drop table %s' % (table))
    sqlStr = 'create table %s (%s)' % (table, colDefs)
    try:
        c.execute(sqlStr)
    except Exception:
        print('create table failed: ', sqlStr)
        print('here is the exception:')
        import traceback
        traceback.print_exc()
        return
    cn.commit()
    c = None

    # nothing to insert (e.g. a file with only a header line); without this
    # guard data[0] below raises IndexError
    if len(data) == 0:
        return

    placeHolders = ','.join([DbModule.placeHolder] * len(data[0]))
    sqlStr = 'insert into %s values (%s)' % (table, placeHolders)
    block = []
    for row in data:
        entries = [None] * len(row)
        for col, val in enumerate(row):
            if val is None or (nullMarker is not None and val == nullMarker):
                continue  # stays None -> NULL
            colType = colTypes[col][0]
            if colType == float:
                entries[col] = float(val)
            elif colType == int:
                entries[col] = int(val)
            else:
                entries[col] = str(val)
        block.append(tuple(entries))
        if len(block) >= blockSize:
            _insertBlock(cn, sqlStr, block)
            # connections running in autocommit mode need no explicit commit
            if not hasattr(cn, 'autocommit') or not cn.autocommit:
                cn.commit()
            block = []
    if len(block):
        _insertBlock(cn, sqlStr, block)
        if not hasattr(cn, 'autocommit') or not cn.autocommit:
            cn.commit()
369
370
def TextFileToDatabase(dBase, table, inF, delim=',', user='sysdba', password='masterkey',
                       maxColLabelLen=31, keyCol=None, nullMarker=None):
    """loads the contents of the text file into a database.

    **Arguments**

      - dBase: the name of the DB to use

      - table: the name of the table to create/overwrite

      - inF: the file like object from which the data should
        be pulled (must support readline())

      - delim: the delimiter used to separate fields

      - user: the user name to use in connecting to the DB

      - password: the password to use in connecting to the DB

      - maxColLabelLen: the maximum length a column label should be
        allowed to have (truncation otherwise)

      - keyCol: the column to be used as an index for the db

      - nullMarker: (optional) values equal to this are treated as missing

    **Notes**

      - if _table_ already exists, it is destroyed before we write
        the new data

      - we assume that the first row of the file contains the column names

      - dashes and spaces in the table name are replaced with underscores

      - every data line must have as many fields as the header
        (AssertionError otherwise)

    """
    # str.replace returns a new string; the result has to be kept
    table = table.replace('-', '_').replace(' ', '_')

    colHeadings = inF.readline().split(delim)
    _AdjustColHeadings(colHeadings, maxColLabelLen)
    nCols = len(colHeadings)
    data = []
    inL = inF.readline()
    while inL:
        inL = inL.replace('\r', '').replace('\n', '')
        splitL = inL.split(delim)
        if len(splitL) != nCols:
            print('>>>', repr(inL))
            assert len(splitL) == nCols, 'unequal length'
        tmpVect = []
        for entry in splitL:
            # use the narrowest numeric type that parses, else keep the text
            try:
                val = int(entry)
            except ValueError:
                try:
                    val = float(entry)
                except ValueError:
                    val = entry
            tmpVect.append(val)
        data.append(tmpVect)
        inL = inF.readline()
    nRows = len(data)

    # determine the types of each column and build the table definition
    colTypes = TypeFinder(data, nRows, nCols, nullMarker=nullMarker)
    typeStrs = GetTypeStrings(colHeadings, colTypes, keyCol=keyCol)
    colDefs = ','.join(typeStrs)

    _AddDataToDb(dBase, table, user, password, colDefs, colTypes, data, nullMarker=nullMarker)
438
439
def DatabaseToDatabase(fromDb, fromTbl, toDb, toTbl, fields='*', join='', where='', user='sysdba',
                       password='masterkey', keyCol=None, nullMarker='None'):
    """ copies the results of a query on one database into a table of another

    The source is dumped to text with DatabaseToText() and the text is then
    loaded into the target with TextFileToDatabase().

    FIX: at the moment this is a hack

    **Arguments**

      - fromDb, fromTbl: the database and table to read from

      - toDb, toTbl: the database and table to write to

      - fields, join, where: passed on to DatabaseToText()

      - user, password: used for both databases

      - keyCol: passed on to TextFileToDatabase()

      - nullMarker: passed on to TextFileToDatabase()

    """
    dumped = DatabaseToText(fromDb, fromTbl, fields=fields, join=join, where=where, user=user,
                            password=password)
    buf = StringIO()
    buf.write(dumped)
    buf.seek(0)
    TextFileToDatabase(toDb, toTbl, buf, user=user, password=password, keyCol=keyCol,
                       nullMarker=nullMarker)
454
455
if __name__ == '__main__':
    # Small smoke test: load a three-column CSV snippet into the 'fromtext'
    # table of the test database shipped in the Dbase directory.
    import os
    from rdkit import RDConfig

    sio = StringIO()
    for textLine in ('foo,bar,baz\n', '1,2,3\n', '1.1,4,5\n', '4,foo,6\n'):
        sio.write(textLine)
    sio.seek(0)
    dirLoc = os.path.join(RDConfig.RDCodeDir, 'Dbase', 'TEST.GDB')

    TextFileToDatabase(dirLoc, 'fromtext', sio)
468