Package rdkit :: Package Dbase :: Module DbUtils
[hide private]
[frames | no frames]

Source Code for Module rdkit.Dbase.DbUtils

  1  # $Id$ 
  2  # 
  3  #  Copyright (C) 2000-2006  greg Landrum and Rational Discovery LLC 
  4  # 
  5  #   @@ All Rights Reserved @@ 
  6  #  This file is part of the RDKit. 
  7  #  The contents are covered by the terms of the BSD license 
  8  #  which is included in the file license.txt, found at the root 
  9  #  of the RDKit source tree. 
 10  # 
 11  """ a set of functions for interacting with databases 
 12   
 13   When possible, it's probably preferable to use a _DbConnection.DbConnect_ object 
 14   
 15  """ 
 16  from __future__ import print_function 
 17   
 18  import sys 
 19   
 20  from rdkit.Dbase import DbInfo 
 21  from rdkit.Dbase import DbModule 
 22  from rdkit.Dbase.DbResultSet import DbResultSet, RandomAccessDbResultSet 
 23  from rdkit.six import string_types, StringIO 
 24  from rdkit.six.moves import xrange 
 25   
 26   
27 -def _take(fromL, what):
28 """ Given a list fromL, returns an iterator of the elements specified using their 29 indices in the list what """ 30 return map(lambda x, y=fromL: y[x], what)
31 32
def GetColumns(dBase, table, fieldString, user='sysdba', password='masterkey', join='', cn=None):
    """ pulls the requested columns out of a table

    **Arguments**

      - dBase: database name (only used when no connection is supplied)

      - table: table name

      - fieldString: comma delimited string with the names of the
        fields to be extracted

      - user and password: credentials handed to DbModule.connect() when
        no connection is supplied

      - join: a join clause; the verb 'join' is added if the clause
        does not already start with it

      - cn: an already open connection to use instead of opening one

    **Returns**

      - whatever the cursor's fetchall() returns for the query

    """
    connection = cn if cn else DbModule.connect(dBase, user, password)
    cursor = connection.cursor()
    pieces = ['select %s from %s' % (fieldString, table)]
    if join:
        if not join.strip().startswith('join'):
            join = 'join %s' % (join)
        pieces.append(join)
    cursor.execute(' '.join(pieces))
    return cursor.fetchall()
65 66
def GetData(dBase, table, fieldString='*', whereString='', user='sysdba', password='masterkey',
            removeDups=-1, join='', forceList=0, transform=None, randomAccess=1, extras=None,
            cn=None):
    """ a more flexible method to get a set of data from a table

    **Arguments**

      - dBase: database name (only used when no connection is supplied)

      - table: table name

      - fieldString: a string with the names of the fields to be extracted,
        this should be a comma delimited list

      - whereString: the SQL where clause to be used with the DB query;
        the keyword 'where' is added if the clause does not start with it

      - user and password: credentials handed to DbModule.connect() when
        no connection is supplied

      - removeDups indicates the column which should be used to screen
        out duplicates.  Only the first appearance of a duplicate will
        be left in the dataset.  Any value > -1 switches on list mode.

      - join: a join clause; the verb 'join' is added if the clause
        does not start with it

      - forceList: if set, the rows are fetched and returned as a list

      - transform: forwarded to the result set object; not allowed in
        list mode

      - randomAccess: selects RandomAccessDbResultSet (set) or DbResultSet
        (unset) for the non-list return value; must be set in list mode

      - extras: passed as the second argument to cursor.execute() in list
        mode, forwarded to the result set object otherwise

      - cn: an already open connection to use instead of opening one

    **Returns**

      - list mode: a list of the rows, or None if executing the query
        raised an exception (the traceback is written to stderr)

      - otherwise: a result set object wrapping the query

    **Raises**

      - ValueError if list mode is combined with a transform or with
        randomAccess unset

    """
    # screening duplicates requires the whole result in memory, so it implies
    # list mode; settle that first so bad argument combinations are rejected
    # before the database is touched
    if removeDups > -1:
        forceList = True
    if forceList and (transform is not None):
        raise ValueError('forceList and transform arguments are not compatible')
    if forceList and (not randomAccess):
        raise ValueError('when forceList is set, randomAccess must also be used')

    if not cn:
        cn = DbModule.connect(dBase, user, password)
    c = cn.cursor()
    cmd = 'select %s from %s' % (fieldString, table)
    if join:
        if not join.strip().startswith('join'):
            join = 'join %s' % (join)
        cmd += ' ' + join
    if whereString:
        if not whereString.strip().startswith('where'):
            whereString = 'where %s' % (whereString)
        cmd += ' ' + whereString

    if not forceList:
        klass = RandomAccessDbResultSet if randomAccess else DbResultSet
        return klass(c, cn, cmd, removeDups=removeDups, transform=transform, extras=extras)

    try:
        if not extras:
            c.execute(cmd)
        else:
            c.execute(cmd, extras)
    except Exception:
        # deliberate best-effort: report the problem and hand back None
        sys.stderr.write('the command "%s" generated errors:\n' % (cmd))
        import traceback
        traceback.print_exc()
        return None

    data = c.fetchall()
    if removeDups >= 0:
        # single pass that keeps the first row seen for each key, in the
        # original order (list.remove() would drop the *first* equal row,
        # i.e. the one we want to keep, when a whole row is repeated)
        seen = set()
        unique = []
        for entry in data:
            key = entry[removeDups]
            if key not in seen:
                seen.add(key)
                unique.append(entry)
        data = unique
    return data
145 146
def DatabaseToText(dBase, table, fields='*', join='', where='', user='sysdba', password='masterkey',
                   delim=',', cn=None):
    """ Pulls the contents of a database and makes a delimited text file from them

    **Arguments**
      - dBase: the name of the DB file to be used (only used when no
        connection is supplied)

      - table: the name of the table to query

      - fields: the fields to select with the SQL query

      - join: the join clause of the SQL query
        (e.g. 'join foo on foo.bar=base.bar')

      - where: the where clause of the SQL query
        (e.g. 'where foo = 2' or  'where bar > 17.6')

      - user: the username for DB access

      - password: the password to be used for DB access

      - delim: the string placed between fields on each output line

      - cn: an already open connection to use instead of opening one

    **Returns**

      - the delimited data (as text): one header line with the column
        names followed by one line per row.  Columns whose type code is
        in DbInfo.sqlBinTypes are left out.

    """
    # treat None like "no clause", the same way GetData does
    where = where or ''
    join = join or ''
    # the keyword is only present if the clause *starts* with it; a clause
    # that merely contains the text (e.g. "name = 'nowhere'" or a sub-select)
    # still needs it added
    if where and not where.strip().startswith('where'):
        where = 'where %s' % (where)
    # the join test is intentionally a substring test so that clauses such as
    # 'left join foo on ...' are passed through untouched
    if join and join.strip().find('join') == -1:
        join = 'join %s' % (join)
    sqlCommand = 'select %s from %s %s %s' % (fields, table, join, where)
    if not cn:
        cn = DbModule.connect(dBase, user, password)
    c = cn.cursor()
    c.execute(sqlCommand)

    # the description field of the cursor carries around info about the columns
    # of the table: item[0] is the name, item[1] the type code
    headers = []
    colsToTake = []
    for i, item in enumerate(c.description):
        if item[1] not in DbInfo.sqlBinTypes:
            colsToTake.append(i)
            headers.append(item[0])

    lines = [delim.join(headers)]
    for res in c.fetchall():
        lines.append(delim.join(map(str, _take(res, colsToTake))))

    return '\n'.join(lines)
202 203
def TypeFinder(data, nRows, nCols, nullMarker=None):
    """

    finds the types of the columns in _data_

    if nullMarker is not None, text elements of the data table which are
    equal to nullMarker will not count towards setting the type of
    their columns.

    **Arguments**

      - data: a table indexed as data[row][col]

      - nRows, nCols: the number of rows and columns to examine

      - nullMarker: (optional) the text used to mark missing values

    **Returns**

      - a list with one [type, width] entry per column.  type is str as
        soon as the column holds any non-null text, otherwise float if
        any value is non-integral, otherwise int; it stays -1 for a
        column with no usable values.  width is the length of the longest
        str() form seen (at least 1).

    **Notes**

      - None entries are skipped

      - anything whose type is not exactly int or float is treated as text

    """
    priorities = {float: 3, int: 2, str: 1, -1: -1}
    res = [None] * nCols
    for col in range(nCols):
        typeHere = [-1, 1]
        for row in range(nRows):
            d = data[row][col]
            if d is None:
                continue
            locType = type(d)
            if locType != float and locType != int:
                # text (or anything else non-numeric)
                try:
                    d = str(d)
                except UnicodeError as msg:
                    # row + 2: 1-based numbering plus the header line of the file
                    print('cannot convert text from row %d col %d to a string' % (row + 2, col))
                    print('\t>%s' % (repr(d)))
                    raise UnicodeError(msg)
                if nullMarker is None or d != nullMarker:
                    typeHere = [str, max(len(d), typeHere[1])]
                continue

            typeHere[1] = max(typeHere[1], len(str(d)))
            try:
                if float(int(d)) == d:
                    locType = int
            except (OverflowError, ValueError):
                # inf raises OverflowError and nan raises ValueError in int();
                # integers too large for float() raise OverflowError
                locType = float
            # once a column has been found to hold text it must stay text,
            # otherwise later inserts would try int('foo')
            if typeHere[0] is not str and priorities[locType] > priorities[typeHere[0]]:
                typeHere[0] = locType
        res[col] = typeHere
    return res
250 251
252 -def _AdjustColHeadings(colHeadings, maxColLabelLen):
253 """ *For Internal Use* 254 255 removes illegal characters from column headings 256 and truncates those which are too long. 257 258 """ 259 for i in xrange(len(colHeadings)): 260 # replace unallowed characters and strip extra white space 261 colHeadings[i] = colHeadings[i].strip() 262 colHeadings[i] = colHeadings[i].replace(' ', '_') 263 colHeadings[i] = colHeadings[i].replace('-', '_') 264 colHeadings[i] = colHeadings[i].replace('.', '_') 265 266 if len(colHeadings[i]) > maxColLabelLen: 267 # interbase (at least) has a limit on the maximum length of a column name 268 newHead = colHeadings[i].replace('_', '') 269 newHead = newHead[:maxColLabelLen] 270 print('\tHeading %s too long, changed to %s' % (colHeadings[i], newHead)) 271 colHeadings[i] = newHead 272 return colHeadings
273 274
def GetTypeStrings(colHeadings, colTypes, keyCol=None):
    """  builds the column definitions for a 'create table' statement

    **Arguments**

      - colHeadings: the column names

      - colTypes: one (type, width) entry per column; float maps to
        'double precision', int to 'integer' and anything else to
        'varchar(width)'

      - keyCol: (optional) name of the column to be declared
        'not null primary key'

    **Returns**

      - a list of SQL type strings, one per entry in colTypes

    """
    typeStrs = []
    for idx, colType in enumerate(colTypes):
        heading = colHeadings[idx]
        kind = colType[0]
        if kind == float:
            decl = '%s double precision' % heading
        elif kind == int:
            decl = '%s integer' % heading
        else:
            decl = '%s varchar(%d)' % (heading, colType[1])
        if heading == keyCol:
            decl = '%s not null primary key' % (decl)
        typeStrs.append(decl)
    return typeStrs
290 291
292 -def _insertBlock(conn, sqlStr, block, silent=False):
293 try: 294 conn.cursor().executemany(sqlStr, block) 295 except Exception: 296 res = 0 297 conn.commit() 298 for row in block: 299 try: 300 conn.cursor().execute(sqlStr, tuple(row)) 301 res += 1 302 except Exception: 303 if not silent: 304 import traceback 305 traceback.print_exc() 306 print('insert failed:', sqlStr) 307 print('\t', repr(row)) 308 else: 309 conn.commit() 310 else: 311 res = len(block) 312 return res
313 314
315 -def _AddDataToDb(dBase, table, user, password, colDefs, colTypes, data, nullMarker=None, 316 blockSize=100, cn=None):
317 """ *For Internal Use* 318 319 (drops and) creates a table and then inserts the values 320 321 """ 322 if not cn: 323 cn = DbModule.connect(dBase, user, password) 324 c = cn.cursor() 325 try: 326 c.execute('drop table %s' % (table)) 327 except Exception: 328 print('cannot drop table %s' % (table)) 329 try: 330 sqlStr = 'create table %s (%s)' % (table, colDefs) 331 c.execute(sqlStr) 332 except Exception: 333 print('create table failed: ', sqlStr) 334 print('here is the exception:') 335 import traceback 336 traceback.print_exc() 337 return 338 cn.commit() 339 c = None 340 341 block = [] 342 entryTxt = [DbModule.placeHolder] * len(data[0]) 343 dStr = ','.join(entryTxt) 344 sqlStr = 'insert into %s values (%s)' % (table, dStr) 345 nDone = 0 346 for row in data: 347 entries = [None] * len(row) 348 for col in xrange(len(row)): 349 if row[col] is not None and \ 350 (nullMarker is None or row[col] != nullMarker): 351 if colTypes[col][0] == float: 352 entries[col] = float(row[col]) 353 elif colTypes[col][0] == int: 354 entries[col] = int(row[col]) 355 else: 356 entries[col] = str(row[col]) 357 else: 358 entries[col] = None 359 block.append(tuple(entries)) 360 if len(block) >= blockSize: 361 nDone += _insertBlock(cn, sqlStr, block) 362 if not hasattr(cn, 'autocommit') or not cn.autocommit: 363 cn.commit() 364 block = [] 365 if len(block): 366 nDone += _insertBlock(cn, sqlStr, block) 367 if not hasattr(cn, 'autocommit') or not cn.autocommit: 368 cn.commit()
369 370
def TextFileToDatabase(dBase, table, inF, delim=',', user='sysdba', password='masterkey',
                       maxColLabelLen=31, keyCol=None, nullMarker=None):
    """loads the contents of the text file into a database.

    **Arguments**

      - dBase: the name of the DB to use

      - table: the name of the table to create/overwrite
        ('-' and ' ' in the name are replaced with '_')

      - inF: the file like object from which the data should
        be pulled (must support readline())

      - delim: the delimiter used to separate fields

      - user: the user name to use in connecting to the DB

      - password: the password to use in connecting to the DB

      - maxColLabelLen: the maximum length a column label should be
        allowed to have (truncation otherwise)

      - keyCol: the column to be used as an index for the db

      - nullMarker: (optional) text marking missing values; passed on
        to TypeFinder() and _AddDataToDb()

    **Raises**

      - AssertionError if a line does not have the same number of
        fields as the header line

    **Notes**

      - if _table_ already exists, it is destroyed before we write
        the new data

      - we assume that the first row of the file contains the column names

      - each field is read as an int if possible, then as a float,
        and is otherwise kept as text

    """
    # str.replace() returns a new string; the result has to be kept
    table = table.replace('-', '_').replace(' ', '_')

    colHeadings = inF.readline().split(delim)
    _AdjustColHeadings(colHeadings, maxColLabelLen)
    nCols = len(colHeadings)
    data = []
    inL = inF.readline()
    while inL:
        inL = inL.replace('\r', '').replace('\n', '')
        splitL = inL.split(delim)
        if len(splitL) != nCols:
            print('>>>', repr(inL))
            # raised explicitly (rather than via an assert statement) so that
            # the check is still made when python runs with -O
            raise AssertionError('unequal length')
        tmpVect = []
        for entry in splitL:
            try:
                val = int(entry)
            except ValueError:
                try:
                    val = float(entry)
                except ValueError:
                    val = entry
            tmpVect.append(val)
        data.append(tmpVect)
        inL = inF.readline()
    nRows = len(data)

    # determine the types of each column
    colTypes = TypeFinder(data, nRows, nCols, nullMarker=nullMarker)
    typeStrs = GetTypeStrings(colHeadings, colTypes, keyCol=keyCol)
    colDefs = ','.join(typeStrs)

    _AddDataToDb(dBase, table, user, password, colDefs, colTypes, data, nullMarker=nullMarker)
438 439
def DatabaseToDatabase(fromDb, fromTbl, toDb, toTbl, fields='*', join='', where='', user='sysdba',
                       password='masterkey', keyCol=None, nullMarker='None'):
    """ copies the results of a query on one database into a table of another

    FIX: at the moment this is a hack

    The data makes a round trip through text: the output of
    DatabaseToText() is placed in an in-memory file which is then
    loaded with TextFileToDatabase().

    **Arguments**

      - fromDb, fromTbl: the database and table to read from

      - toDb, toTbl: the database and table to write to

      - fields, join, where: passed to DatabaseToText()

      - user, password: used for both databases

      - keyCol, nullMarker: passed to TextFileToDatabase()

    """
    text = DatabaseToText(fromDb, fromTbl, fields=fields, join=join, where=where, user=user,
                          password=password)
    buf = StringIO()
    buf.write(text)
    # rewind so that the loader starts reading at the header line
    buf.seek(0)
    TextFileToDatabase(toDb, toTbl, buf, user=user, password=password, keyCol=keyCol,
                       nullMarker=nullMarker)
if __name__ == '__main__':  # pragma: nocover
    # small demo: load a four line delimited text sample into the table
    # 'fromtext' of the TEST.GDB database in the Dbase code directory
    import os
    from rdkit import RDConfig

    sio = StringIO()
    for demoLine in ('foo,bar,baz\n', '1,2,3\n', '1.1,4,5\n', '4,foo,6\n'):
        sio.write(demoLine)
    sio.seek(0)
    dirLoc = os.path.join(RDConfig.RDCodeDir, 'Dbase', 'TEST.GDB')

    TextFileToDatabase(dirLoc, 'fromtext', sio)