Diffstat (limited to 'google_appengine/google/appengine/tools/bulkloader.py')
-rwxr-xr-x  google_appengine/google/appengine/tools/bulkloader.py  3827
1 file changed, 3827 insertions, 0 deletions
diff --git a/google_appengine/google/appengine/tools/bulkloader.py b/google_appengine/google/appengine/tools/bulkloader.py
new file mode 100755
index 0000000..e288b00
--- /dev/null
+++ b/google_appengine/google/appengine/tools/bulkloader.py
@@ -0,0 +1,3827 @@
+#!/usr/bin/env python
+#
+# Copyright 2007 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""Imports data over HTTP.
+
+Usage:
+ %(arg0)s [flags]
+
+ --debug Show debugging information. (Optional)
+ --app_id=<string> Application ID of endpoint (Optional for
+ *.appspot.com)
+ --auth_domain=<domain> The auth domain to use for logging in and for
+ UserProperties. (Default: gmail.com)
+ --bandwidth_limit=<int> The maximum number of bytes per second for the
+ aggregate transfer of data to the server. Bursts
+ may exceed this, but overall transfer rate is
+ restricted to this rate. (Default 250000)
+ --batch_size=<int> Number of Entity objects to include in each post to
+ the URL endpoint. The more data per row/Entity, the
+ smaller the batch size should be. (Default 10)
+ --config_file=<path> File containing Model and Loader definitions.
+ (Required unless --dump or --restore are used)
+ --db_filename=<path> Specific progress database to write to, or to
+ resume from. If not supplied, then a new database
+ will be started, named:
+ bulkloader-progress-TIMESTAMP.
+ The special filename "skip" may be used to simply
+ skip reading/writing any progress information.
+ --download Export entities to a file.
+ --dry_run Do not execute any remote_api calls.
+ --dump Use zero-configuration dump format.
+ --email=<string> The username to use. Will prompt if omitted.
+ --exporter_opts=<string>
+ A string to pass to the Exporter.initialize method.
+ --filename=<path> Path to the file to import. (Required)
+ --has_header Skip the first row of the input.
+  --http_limit=<int>      The maximum number of HTTP requests per second to
+ send to the server. (Default: 8)
+ --kind=<string> Name of the Entity object kind to put in the
+ datastore. (Required)
+ --loader_opts=<string> A string to pass to the Loader.initialize method.
+ --log_file=<path> File to write bulkloader logs. If not supplied
+ then a new log file will be created, named:
+ bulkloader-log-TIMESTAMP.
+ --map Map an action across datastore entities.
+ --mapper_opts=<string> A string to pass to the Mapper.Initialize method.
+ --num_threads=<int> Number of threads to use for uploading entities
+ (Default 10)
+ --passin Read the login password from stdin.
+ --restore Restore from zero-configuration dump format.
+ --result_db_filename=<path>
+ Result database to write to for downloads.
+ --rps_limit=<int> The maximum number of records per second to
+ transfer to the server. (Default: 20)
+ --url=<string> URL endpoint to post to for importing data.
+ (Required)
+
+The exit status will be 0 on success, non-zero on import failure.
+
+Works with the remote_api mix-in library for google.appengine.ext.remote_api.
+Please look there for documentation about how to set up the server side.
+
+Example:
+
+%(arg0)s --url=http://app.appspot.com/remote_api --kind=Model \
+ --filename=data.csv --config_file=loader_config.py
+
+"""
+
+
+
+import csv
+import errno
+import getopt
+import getpass
+import imp
+import logging
+import os
+import Queue
+import re
+import shutil
+import signal
+import StringIO
+import sys
+import threading
+import time
+import traceback
+import urllib2
+import urlparse
+
+from google.appengine.datastore import entity_pb
+
+from google.appengine.api import apiproxy_stub_map
+from google.appengine.api import datastore
+from google.appengine.api import datastore_errors
+from google.appengine.datastore import datastore_pb
+from google.appengine.ext import db
+from google.appengine.ext import key_range as key_range_module
+from google.appengine.ext.db import polymodel
+from google.appengine.ext.remote_api import remote_api_stub
+from google.appengine.ext.remote_api import throttle as remote_api_throttle
+from google.appengine.runtime import apiproxy_errors
+from google.appengine.tools import adaptive_thread_pool
+from google.appengine.tools import appengine_rpc
+from google.appengine.tools.requeue import ReQueue
+
+try:
+ import sqlite3
+except ImportError:
+ pass
+
+logger = logging.getLogger('google.appengine.tools.bulkloader')
+
+KeyRange = key_range_module.KeyRange
+
+DEFAULT_THREAD_COUNT = 10
+
+DEFAULT_BATCH_SIZE = 10
+
+DEFAULT_DOWNLOAD_BATCH_SIZE = 100
+
+DEFAULT_QUEUE_SIZE = DEFAULT_THREAD_COUNT * 10
+
+_THREAD_SHOULD_EXIT = '_THREAD_SHOULD_EXIT'
+
+STATE_READ = 0
+STATE_SENDING = 1
+STATE_SENT = 2
+STATE_NOT_SENT = 3
+
+STATE_GETTING = 1
+STATE_GOT = 2
+STATE_ERROR = 3
+
+DATA_CONSUMED_TO_HERE = 'DATA_CONSUMED_TO_HERE'
+
+INITIAL_BACKOFF = 1.0
+
+BACKOFF_FACTOR = 2.0
+
+
+DEFAULT_BANDWIDTH_LIMIT = 250000
+
+DEFAULT_RPS_LIMIT = 20
+
+DEFAULT_REQUEST_LIMIT = 8
+
+MAXIMUM_INCREASE_DURATION = 5.0
+MAXIMUM_HOLD_DURATION = 12.0
+
+
+def ImportStateMessage(state):
+ """Converts a numeric state identifier to a status message."""
+ return ({
+ STATE_READ: 'Batch read from file.',
+ STATE_SENDING: 'Sending batch to server.',
+ STATE_SENT: 'Batch successfully sent.',
+ STATE_NOT_SENT: 'Error while sending batch.'
+ }[state])
+
+
+def ExportStateMessage(state):
+ """Converts a numeric state identifier to a status message."""
+ return ({
+ STATE_READ: 'Batch read from file.',
+ STATE_GETTING: 'Fetching batch from server',
+ STATE_GOT: 'Batch successfully fetched.',
+ STATE_ERROR: 'Error while fetching batch'
+ }[state])
+
+
+def MapStateMessage(state):
+ """Converts a numeric state identifier to a status message."""
+ return ({
+ STATE_READ: 'Batch read from file.',
+ STATE_GETTING: 'Querying for batch from server',
+ STATE_GOT: 'Batch successfully fetched.',
+ STATE_ERROR: 'Error while fetching or mapping.'
+ }[state])
+
+
+def ExportStateName(state):
+ """Converts a numeric state identifier to a string."""
+ return ({
+ STATE_READ: 'READ',
+ STATE_GETTING: 'GETTING',
+ STATE_GOT: 'GOT',
+ STATE_ERROR: 'NOT_GOT'
+ }[state])
+
+
+def ImportStateName(state):
+ """Converts a numeric state identifier to a string."""
+ return ({
+ STATE_READ: 'READ',
+ STATE_GETTING: 'SENDING',
+ STATE_GOT: 'SENT',
+ STATE_NOT_SENT: 'NOT_SENT'
+ }[state])
+
+
+class Error(Exception):
+ """Base-class for exceptions in this module."""
+
+
+class MissingPropertyError(Error):
+ """An expected field is missing from an entity, and no default was given."""
+
+
+class FatalServerError(Error):
+ """An unrecoverable error occurred while posting data to the server."""
+
+
+class ResumeError(Error):
+ """Error while trying to resume a partial upload."""
+
+
+class ConfigurationError(Error):
+ """Error in configuration options."""
+
+
+class AuthenticationError(Error):
+ """Error while trying to authenticate with the server."""
+
+
+class FileNotFoundError(Error):
+ """A filename passed in by the user refers to a non-existent input file."""
+
+
+class FileNotReadableError(Error):
+ """A filename passed in by the user refers to a non-readable input file."""
+
+
+class FileExistsError(Error):
+ """A filename passed in by the user refers to an existing output file."""
+
+
+class FileNotWritableError(Error):
+ """A filename passed in by the user refers to a non-writable output file."""
+
+
+class BadStateError(Error):
+ """A work item in an unexpected state was encountered."""
+
+
+class KeyRangeError(Error):
+ """An error during construction of a KeyRangeItem."""
+
+
+class FieldSizeLimitError(Error):
+ """The csv module tried to read a field larger than the size limit."""
+
+ def __init__(self, limit):
+ self.message = """
+A field in your CSV input file has exceeded the current limit of %d.
+
+You can raise this limit by adding the following lines to your config file:
+
+import csv
+csv.field_size_limit(new_limit)
+
+where new_limit is a number larger than the size in bytes of the largest
+field in your CSV.
+""" % limit
+ Error.__init__(self, self.message)
+
+
+class NameClashError(Error):
+ """A name clash occurred while trying to alias old method names."""
+
+ def __init__(self, old_name, new_name, klass):
+ Error.__init__(self, old_name, new_name, klass)
+ self.old_name = old_name
+ self.new_name = new_name
+ self.klass = klass
+
+
+def GetCSVGeneratorFactory(kind, csv_filename, batch_size, csv_has_header,
+ openfile=open, create_csv_reader=csv.reader):
+ """Return a factory that creates a CSV-based UploadWorkItem generator.
+
+ Args:
+ kind: The kind of the entities being uploaded.
+ csv_filename: File on disk containing CSV data.
+ batch_size: Maximum number of CSV rows to stash into an UploadWorkItem.
+ csv_has_header: Whether to skip the first row of the CSV.
+ openfile: Used for dependency injection.
+ create_csv_reader: Used for dependency injection.
+
+ Returns:
+ A callable (accepting the Progress Queue and Progress Generators
+ as input) which creates the UploadWorkItem generator.
+ """
+ loader = Loader.RegisteredLoader(kind)
+ loader._Loader__openfile = openfile
+ loader._Loader__create_csv_reader = create_csv_reader
+ record_generator = loader.generate_records(csv_filename)
+
+ def CreateGenerator(request_manager, progress_queue, progress_generator):
+ """Initialize a UploadWorkItem generator.
+
+ Args:
+ request_manager: A RequestManager instance.
+ progress_queue: A ProgressQueue instance to send progress information.
+ progress_generator: A generator of progress information or None.
+
+ Returns:
+ An UploadWorkItemGenerator instance.
+ """
+ return UploadWorkItemGenerator(request_manager,
+ progress_queue,
+ progress_generator,
+ record_generator,
+ csv_has_header,
+ batch_size)
+
+ return CreateGenerator
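+
+# Illustrative usage sketch (hypothetical names; in practice the bulkloader
+# wires this up from the command-line flags, and a Loader for the kind must
+# already be registered by the --config_file module):
+#
+#   factory = GetCSVGeneratorFactory('Person', 'people.csv',
+#                                    batch_size=10, csv_has_header=True)
+#   generator = factory(request_manager, progress_queue, progress_generator)
+#   for work_item in generator.Batches():
+#     work_item.MarkAsRead()
+#     thread_pool.SubmitItem(work_item, block=True)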
+
+
+class UploadWorkItemGenerator(object):
+ """Reads rows from a row generator and generates UploadWorkItems."""
+
+ def __init__(self,
+ request_manager,
+ progress_queue,
+ progress_generator,
+ record_generator,
+ skip_first,
+ batch_size):
+ """Initialize a WorkItemGenerator.
+
+ Args:
+ request_manager: A RequestManager instance with which to associate
+ WorkItems.
+ progress_queue: A progress queue with which to associate WorkItems.
+ progress_generator: A generator of progress information.
+ record_generator: A generator of data records.
+ skip_first: Whether to skip the first data record.
+ batch_size: The number of data records per WorkItem.
+ """
+ self.request_manager = request_manager
+ self.progress_queue = progress_queue
+ self.progress_generator = progress_generator
+ self.reader = record_generator
+ self.skip_first = skip_first
+ self.batch_size = batch_size
+ self.line_number = 1
+ self.column_count = None
+ self.read_rows = []
+ self.row_count = 0
+ self.xfer_count = 0
+
+ def _AdvanceTo(self, line):
+ """Advance the reader to the given line.
+
+ Args:
+ line: A line number to advance to.
+ """
+ while self.line_number < line:
+ self.reader.next()
+ self.line_number += 1
+ self.row_count += 1
+ self.xfer_count += 1
+
+ def _ReadRows(self, key_start, key_end):
+ """Attempts to read and encode rows [key_start, key_end].
+
+ The encoded rows are stored in self.read_rows.
+
+ Args:
+ key_start: The starting line number.
+ key_end: The ending line number.
+
+ Raises:
+ StopIteration: if the reader runs out of rows
+      ResumeError: if there is an inconsistent number of columns.
+ """
+ assert self.line_number == key_start
+ self.read_rows = []
+ while self.line_number <= key_end:
+ row = self.reader.next()
+ self.row_count += 1
+ if self.column_count is None:
+ self.column_count = len(row)
+ else:
+ if self.column_count != len(row):
+ raise ResumeError('Column count mismatch, %d: %s' %
+ (self.column_count, str(row)))
+ self.read_rows.append((self.line_number, row))
+ self.line_number += 1
+
+ def _MakeItem(self, key_start, key_end, rows, progress_key=None):
+ """Makes a UploadWorkItem containing the given rows, with the given keys.
+
+ Args:
+ key_start: The start key for the UploadWorkItem.
+ key_end: The end key for the UploadWorkItem.
+ rows: A list of the rows for the UploadWorkItem.
+ progress_key: The progress key for the UploadWorkItem
+
+ Returns:
+ An UploadWorkItem instance for the given batch.
+ """
+ assert rows
+
+ item = UploadWorkItem(self.request_manager, self.progress_queue, rows,
+ key_start, key_end, progress_key=progress_key)
+
+ return item
+
+ def Batches(self):
+ """Reads from the record_generator and generates UploadWorkItems.
+
+ Yields:
+ Instances of class UploadWorkItem
+
+ Raises:
+ ResumeError: If the progress database and data file indicate a different
+ number of rows.
+ """
+ if self.skip_first:
+ logger.info('Skipping header line.')
+ try:
+ self.reader.next()
+ except StopIteration:
+ return
+
+ exhausted = False
+
+ self.line_number = 1
+ self.column_count = None
+
+ logger.info('Starting import; maximum %d entities per post',
+ self.batch_size)
+
+ state = None
+ if self.progress_generator:
+ for progress_key, state, key_start, key_end in self.progress_generator:
+ if key_start:
+ try:
+ self._AdvanceTo(key_start)
+ self._ReadRows(key_start, key_end)
+ yield self._MakeItem(key_start,
+ key_end,
+ self.read_rows,
+ progress_key=progress_key)
+ except StopIteration:
+ logger.error('Mismatch between data file and progress database')
+ raise ResumeError(
+ 'Mismatch between data file and progress database')
+ elif state == DATA_CONSUMED_TO_HERE:
+ try:
+ self._AdvanceTo(key_end + 1)
+ except StopIteration:
+ state = None
+
+ if self.progress_generator is None or state == DATA_CONSUMED_TO_HERE:
+ while not exhausted:
+ key_start = self.line_number
+ key_end = self.line_number + self.batch_size - 1
+ try:
+ self._ReadRows(key_start, key_end)
+ except StopIteration:
+ exhausted = True
+ key_end = self.line_number - 1
+ if key_start <= key_end:
+ yield self._MakeItem(key_start, key_end, self.read_rows)
+
+
+class CSVGenerator(object):
+ """Reads a CSV file and generates data records."""
+
+ def __init__(self,
+ csv_filename,
+ openfile=open,
+ create_csv_reader=csv.reader):
+ """Initializes a CSV generator.
+
+ Args:
+ csv_filename: File on disk containing CSV data.
+ openfile: Used for dependency injection of 'open'.
+ create_csv_reader: Used for dependency injection of 'csv.reader'.
+ """
+ self.csv_filename = csv_filename
+ self.openfile = openfile
+ self.create_csv_reader = create_csv_reader
+
+ def Records(self):
+ """Reads the CSV data file and generates row records.
+
+ Yields:
+ Lists of strings
+
+    Raises:
+      FieldSizeLimitError: If a field in the CSV input exceeds the csv
+        module's current field size limit.
+ """
+ csv_file = self.openfile(self.csv_filename, 'rb')
+ reader = self.create_csv_reader(csv_file, skipinitialspace=True)
+ try:
+ for record in reader:
+ yield record
+ except csv.Error, e:
+ if e.args and e.args[0].startswith('field larger than field limit'):
+ limit = e.args[1]
+ raise FieldSizeLimitError(limit)
+ else:
+ raise
+
+
+class KeyRangeItemGenerator(object):
+ """Generates ranges of keys to download.
+
+ Reads progress information from the progress database and creates
+ KeyRangeItem objects corresponding to incompletely downloaded parts of an
+ export.
+ """
+
+ def __init__(self, request_manager, kind, progress_queue, progress_generator,
+ key_range_item_factory):
+ """Initialize the KeyRangeItemGenerator.
+
+ Args:
+ request_manager: A RequestManager instance.
+ kind: The kind of entities being transferred.
+ progress_queue: A queue used for tracking progress information.
+ progress_generator: A generator of prior progress information, or None
+ if there is no prior status.
+ key_range_item_factory: A factory to produce KeyRangeItems.
+ """
+ self.request_manager = request_manager
+ self.kind = kind
+ self.row_count = 0
+ self.xfer_count = 0
+ self.progress_queue = progress_queue
+ self.progress_generator = progress_generator
+ self.key_range_item_factory = key_range_item_factory
+
+ def Batches(self):
+ """Iterate through saved progress information.
+
+ Yields:
+ KeyRangeItem instances corresponding to undownloaded key ranges.
+ """
+ if self.progress_generator is not None:
+ for progress_key, state, key_start, key_end in self.progress_generator:
+ if state is not None and state != STATE_GOT and key_start is not None:
+ key_start = ParseKey(key_start)
+ key_end = ParseKey(key_end)
+
+ key_range = KeyRange(key_start=key_start,
+ key_end=key_end)
+
+ result = self.key_range_item_factory(self.request_manager,
+ self.progress_queue,
+ self.kind,
+ key_range,
+ progress_key=progress_key,
+ state=STATE_READ)
+ yield result
+ else:
+ key_range = KeyRange()
+
+ yield self.key_range_item_factory(self.request_manager,
+ self.progress_queue,
+ self.kind,
+ key_range)
+
+
+class DownloadResult(object):
+ """Holds the result of an entity download."""
+
+ def __init__(self, continued, direction, keys, entities):
+ self.continued = continued
+ self.direction = direction
+ self.keys = keys
+ self.entities = entities
+ self.count = len(keys)
+ assert self.count == len(entities)
+ assert direction in (key_range_module.KeyRange.ASC,
+ key_range_module.KeyRange.DESC)
+ if self.count > 0:
+ if direction == key_range_module.KeyRange.ASC:
+ self.key_start = keys[0]
+ self.key_end = keys[-1]
+ else:
+ self.key_start = keys[-1]
+ self.key_end = keys[0]
+
+ def Entities(self):
+ """Returns the list of entities for this result in key order."""
+ if self.direction == key_range_module.KeyRange.ASC:
+ return list(self.entities)
+ else:
+ result = list(self.entities)
+ result.reverse()
+ return result
+
+ def __str__(self):
+ return 'continued = %s\n%s' % (
+ str(self.continued), '\n'.join(str(self.entities)))
+
+
+class _WorkItem(adaptive_thread_pool.WorkItem):
+ """Holds a description of a unit of upload or download work."""
+
+ def __init__(self, progress_queue, key_start, key_end, state_namer,
+ state=STATE_READ, progress_key=None):
+ """Initialize the _WorkItem instance.
+
+ Args:
+ progress_queue: A queue used for tracking progress information.
+ key_start: The start key of the work item.
+ key_end: The end key of the work item.
+ state_namer: Function to describe work item states.
+ state: The initial state of the work item.
+ progress_key: If this WorkItem represents state from a prior run,
+ then this will be the key within the progress database.
+ """
+ adaptive_thread_pool.WorkItem.__init__(self,
+ '[%s-%s]' % (key_start, key_end))
+ self.progress_queue = progress_queue
+ self.state_namer = state_namer
+ self.state = state
+ self.progress_key = progress_key
+ self.progress_event = threading.Event()
+ self.key_start = key_start
+ self.key_end = key_end
+ self.error = None
+ self.traceback = None
+
+ def _TransferItem(self, thread_pool):
+ raise NotImplementedError()
+
+ def SetError(self):
+ """Sets the error and traceback information for this thread.
+
+ This must be called from an exception handler.
+ """
+ if not self.error:
+ exc_info = sys.exc_info()
+ self.error = exc_info[1]
+ self.traceback = exc_info[2]
+
+ def PerformWork(self, thread_pool):
+ """Perform the work of this work item and report the results.
+
+ Args:
+ thread_pool: An AdaptiveThreadPool instance.
+
+ Returns:
+ A tuple (status, instruction) of the work status and an instruction
+ for the ThreadGate.
+ """
+ status = adaptive_thread_pool.WorkItem.FAILURE
+ instruction = adaptive_thread_pool.ThreadGate.DECREASE
+
+ try:
+ self.MarkAsTransferring()
+
+ try:
+ transfer_time = self._TransferItem(thread_pool)
+ if transfer_time is None:
+ status = adaptive_thread_pool.WorkItem.RETRY
+ instruction = adaptive_thread_pool.ThreadGate.HOLD
+ else:
+ logger.debug('[%s] %s Transferred %d entities in %0.1f seconds',
+ threading.currentThread().getName(), self, self.count,
+ transfer_time)
+ sys.stdout.write('.')
+ sys.stdout.flush()
+ status = adaptive_thread_pool.WorkItem.SUCCESS
+ if transfer_time <= MAXIMUM_INCREASE_DURATION:
+ instruction = adaptive_thread_pool.ThreadGate.INCREASE
+ elif transfer_time <= MAXIMUM_HOLD_DURATION:
+ instruction = adaptive_thread_pool.ThreadGate.HOLD
+ except (db.InternalError, db.NotSavedError, db.Timeout,
+ db.TransactionFailedError,
+ apiproxy_errors.OverQuotaError,
+ apiproxy_errors.DeadlineExceededError,
+ apiproxy_errors.ApplicationError), e:
+ status = adaptive_thread_pool.WorkItem.RETRY
+ logger.exception('Retrying on non-fatal datastore error: %s', e)
+ except urllib2.HTTPError, e:
+ http_status = e.code
+ if http_status == 403 or (http_status >= 500 and http_status < 600):
+ status = adaptive_thread_pool.WorkItem.RETRY
+ logger.exception('Retrying on non-fatal HTTP error: %d %s',
+ http_status, e.msg)
+ else:
+ self.SetError()
+ status = adaptive_thread_pool.WorkItem.FAILURE
+ except urllib2.URLError, e:
+ if IsURLErrorFatal(e):
+ self.SetError()
+ status = adaptive_thread_pool.WorkItem.FAILURE
+ else:
+ status = adaptive_thread_pool.WorkItem.RETRY
+ logger.exception('Retrying on non-fatal URL error: %s', e.reason)
+
+ finally:
+ if status == adaptive_thread_pool.WorkItem.SUCCESS:
+ self.MarkAsTransferred()
+ else:
+ self.MarkAsError()
+
+ return (status, instruction)
+
+ def _AssertInState(self, *states):
+ """Raises an Error if the state of this range is not in states."""
+ if not self.state in states:
+ raise BadStateError('%s:%s not in %s' %
+ (str(self),
+ self.state_namer(self.state),
+ map(self.state_namer, states)))
+
+ def _AssertProgressKey(self):
+ """Raises an Error if the progress key is None."""
+ if self.progress_key is None:
+ raise BadStateError('%s: Progress key is missing' % str(self))
+
+ def MarkAsRead(self):
+ """Mark this _WorkItem as read, updating the progress database."""
+ self._AssertInState(STATE_READ)
+ self._StateTransition(STATE_READ, blocking=True)
+
+ def MarkAsTransferring(self):
+ """Mark this _WorkItem as transferring, updating the progress database."""
+ self._AssertInState(STATE_READ, STATE_ERROR)
+ self._AssertProgressKey()
+ self._StateTransition(STATE_GETTING, blocking=True)
+
+ def MarkAsTransferred(self):
+ """Mark this _WorkItem as transferred, updating the progress database."""
+ raise NotImplementedError()
+
+ def MarkAsError(self):
+ """Mark this _WorkItem as failed, updating the progress database."""
+ self._AssertInState(STATE_GETTING)
+ self._AssertProgressKey()
+ self._StateTransition(STATE_ERROR, blocking=True)
+
+ def _StateTransition(self, new_state, blocking=False):
+ """Transition the work item to a new state, storing progress information.
+
+ Args:
+ new_state: The state to transition to.
+ blocking: Whether to block for the progress thread to acknowledge the
+ transition.
+ """
+ assert not self.progress_event.isSet()
+
+ self.state = new_state
+
+ self.progress_queue.put(self)
+
+ if blocking:
+ self.progress_event.wait()
+
+ self.progress_event.clear()
+
+
+
+class UploadWorkItem(_WorkItem):
+ """Holds a unit of uploading work.
+
+  An UploadWorkItem represents a number of entities that need to be uploaded to
+ Google App Engine. These entities are encoded in the "content" field of
+ the UploadWorkItem, and will be POST'd as-is to the server.
+
+ The entities are identified by a range of numeric keys, inclusively. In
+ the case of a resumption of an upload, or a replay to correct errors,
+ these keys must be able to identify the same set of entities.
+
+ Note that keys specify a range. The entities do not have to sequentially
+  fill the entire range; they must simply bound a range of valid keys.
+ """
+
+ def __init__(self, request_manager, progress_queue, rows, key_start, key_end,
+ progress_key=None):
+ """Initialize the UploadWorkItem instance.
+
+ Args:
+ request_manager: A RequestManager instance.
+ progress_queue: A queue used for tracking progress information.
+ rows: A list of pairs of a line number and a list of column values
+ key_start: The (numeric) starting key, inclusive.
+ key_end: The (numeric) ending key, inclusive.
+ progress_key: If this UploadWorkItem represents state from a prior run,
+ then this will be the key within the progress database.
+ """
+ _WorkItem.__init__(self, progress_queue, key_start, key_end,
+ ImportStateName, state=STATE_READ,
+ progress_key=progress_key)
+
+ assert isinstance(key_start, (int, long))
+ assert isinstance(key_end, (int, long))
+ assert key_start <= key_end
+
+ self.request_manager = request_manager
+ self.rows = rows
+ self.content = None
+ self.count = len(rows)
+
+ def __str__(self):
+ return '[%s-%s]' % (self.key_start, self.key_end)
+
+ def _TransferItem(self, thread_pool, get_time=time.time):
+ """Transfers the entities associated with an item.
+
+ Args:
+ thread_pool: An AdaptiveThreadPool instance.
+ get_time: Used for dependency injection.
+ """
+ t = get_time()
+ if not self.content:
+ self.content = self.request_manager.EncodeContent(self.rows)
+ try:
+ self.request_manager.PostEntities(self.content)
+ except:
+ raise
+ return get_time() - t
+
+ def MarkAsTransferred(self):
+ """Mark this UploadWorkItem as sucessfully-sent to the server."""
+
+ self._AssertInState(STATE_SENDING)
+ self._AssertProgressKey()
+
+ self._StateTransition(STATE_SENT, blocking=False)
+
+
+def GetImplementationClass(kind_or_class_key):
+ """Returns the implementation class for a given kind or class key.
+
+ Args:
+ kind_or_class_key: A kind string or a tuple of kind strings.
+
+  Returns:
+ A db.Model subclass for the given kind or class key.
+ """
+ if isinstance(kind_or_class_key, tuple):
+ try:
+ implementation_class = polymodel._class_map[kind_or_class_key]
+ except KeyError:
+ raise db.KindError('No implementation for class \'%s\'' %
+ kind_or_class_key)
+ else:
+ implementation_class = db.class_for_kind(kind_or_class_key)
+ return implementation_class
+
+
+def KeyLEQ(key1, key2):
+ """Compare two keys for less-than-or-equal-to.
+
+  All keys with numeric ids come before all keys with names. None represents
+  an unbounded end-point, so it is treated as both greater than and less than
+  any other key.
+
+ Args:
+ key1: An int or datastore.Key instance.
+ key2: An int or datastore.Key instance.
+
+ Returns:
+ True if key1 <= key2
+ """
+ if key1 is None or key2 is None:
+ return True
+ return key1 <= key2
+
+
+class KeyRangeItem(_WorkItem):
+ """Represents an item of work that scans over a key range.
+
+  A KeyRangeItem object holds a KeyRange and has an associated state:
+  STATE_READ, STATE_GETTING, STATE_GOT, or STATE_ERROR.
+
+  - STATE_READ indicates the range is ready to be downloaded by a worker
+    thread.
+  - STATE_GETTING indicates the range is currently being downloaded.
+  - STATE_GOT indicates that the range was successfully downloaded.
+  - STATE_ERROR indicates that an error occurred during the last download
+    attempt.
+
+ KeyRangeItems not in the STATE_GOT state are stored in the progress database.
+ When a piece of KeyRangeItem work is downloaded, the download may cover only
+ a portion of the range. In this case, the old KeyRangeItem is removed from
+ the progress database and ranges covering the undownloaded range are
+ generated and stored as STATE_READ in the export progress database.
+ """
+
+ def __init__(self,
+ request_manager,
+ progress_queue,
+ kind,
+ key_range,
+ progress_key=None,
+ state=STATE_READ):
+ """Initialize a KeyRangeItem object.
+
+ Args:
+ request_manager: A RequestManager instance.
+ progress_queue: A queue used for tracking progress information.
+ kind: The kind of entities for this range.
+ key_range: A KeyRange instance for this work item.
+ progress_key: The key for this range within the progress database.
+ state: The initial state of this range.
+ """
+ _WorkItem.__init__(self, progress_queue, key_range.key_start,
+ key_range.key_end, ExportStateName, state=state,
+ progress_key=progress_key)
+ self.request_manager = request_manager
+ self.kind = kind
+ self.key_range = key_range
+ self.download_result = None
+ self.count = 0
+ self.key_start = key_range.key_start
+ self.key_end = key_range.key_end
+
+ def __str__(self):
+ return str(self.key_range)
+
+ def __repr__(self):
+ return self.__str__()
+
+ def MarkAsTransferred(self):
+ """Mark this KeyRangeItem as transferred, updating the progress database."""
+ pass
+
+ def Process(self, download_result, thread_pool, batch_size,
+ new_state=STATE_GOT):
+ """Mark this KeyRangeItem as success, updating the progress database.
+
+ Process will split this KeyRangeItem based on the content of
+ download_result and adds the unfinished ranges to the work queue.
+
+ Args:
+ download_result: A DownloadResult instance.
+ thread_pool: An AdaptiveThreadPool instance.
+ batch_size: The number of entities to transfer per request.
+ new_state: The state to transition the completed range to.
+ """
+ self._AssertInState(STATE_GETTING)
+ self._AssertProgressKey()
+
+ self.download_result = download_result
+ self.count = len(download_result.keys)
+ if download_result.continued:
+ self._FinishedRange()._StateTransition(new_state, blocking=True)
+ self._AddUnfinishedRanges(thread_pool, batch_size)
+ else:
+ self._StateTransition(new_state, blocking=True)
+
+ def _FinishedRange(self):
+ """Returns the range completed by the download_result.
+
+ Returns:
+ A KeyRangeItem representing a completed range.
+ """
+ assert self.download_result is not None
+
+ if self.key_range.direction == key_range_module.KeyRange.ASC:
+ key_start = self.key_range.key_start
+ if self.download_result.continued:
+ key_end = self.download_result.key_end
+ else:
+ key_end = self.key_range.key_end
+ else:
+ key_end = self.key_range.key_end
+ if self.download_result.continued:
+ key_start = self.download_result.key_start
+ else:
+ key_start = self.key_range.key_start
+
+ key_range = KeyRange(key_start=key_start,
+ key_end=key_end,
+ direction=self.key_range.direction)
+
+ result = self.__class__(self.request_manager,
+ self.progress_queue,
+ self.kind,
+ key_range,
+ progress_key=self.progress_key,
+ state=self.state)
+
+ result.download_result = self.download_result
+ result.count = self.count
+ return result
+
+ def _SplitAndAddRanges(self, thread_pool, batch_size):
+ """Split the key range [key_start, key_end] into a list of ranges."""
+ if self.download_result.direction == key_range_module.KeyRange.ASC:
+ key_range = KeyRange(
+ key_start=self.download_result.key_end,
+ key_end=self.key_range.key_end,
+ include_start=False)
+ else:
+ key_range = KeyRange(
+ key_start=self.key_range.key_start,
+ key_end=self.download_result.key_start,
+ include_end=False)
+
+ if thread_pool.QueuedItemCount() > 2 * thread_pool.num_threads():
+ ranges = [key_range]
+ else:
+ ranges = key_range.split_range(batch_size=batch_size)
+
+ for key_range in ranges:
+ key_range_item = self.__class__(self.request_manager,
+ self.progress_queue,
+ self.kind,
+ key_range)
+ key_range_item.MarkAsRead()
+ thread_pool.SubmitItem(key_range_item, block=True)
+
+ def _AddUnfinishedRanges(self, thread_pool, batch_size):
+ """Adds incomplete KeyRanges to the thread_pool.
+
+ Args:
+ thread_pool: An AdaptiveThreadPool instance.
+ batch_size: The number of entities to transfer per request.
+
+ Raises:
+ KeyRangeError: if this key range has already been completely transferred.
+ """
+ assert self.download_result is not None
+ if self.download_result.continued:
+ self._SplitAndAddRanges(thread_pool, batch_size)
+ else:
+ raise KeyRangeError('No unfinished part of key range.')
+
+
+class DownloadItem(KeyRangeItem):
+ """A KeyRangeItem for downloading key ranges."""
+
+ def _TransferItem(self, thread_pool, get_time=time.time):
+ """Transfers the entities associated with an item."""
+ t = get_time()
+ download_result = self.request_manager.GetEntities(self)
+ transfer_time = get_time() - t
+ self.Process(download_result, thread_pool,
+ self.request_manager.batch_size)
+ return transfer_time
+
+
+class MapperItem(KeyRangeItem):
+ """A KeyRangeItem for mapping over key ranges."""
+
+ def _TransferItem(self, thread_pool, get_time=time.time):
+ t = get_time()
+ download_result = self.request_manager.GetEntities(self)
+ transfer_time = get_time() - t
+ mapper = self.request_manager.GetMapper()
+ try:
+ mapper.batch_apply(download_result.Entities())
+ except MapperRetry:
+ return None
+ self.Process(download_result, thread_pool,
+ self.request_manager.batch_size)
+ return transfer_time
+
+
+class RequestManager(object):
+ """A class which wraps a connection to the server."""
+
+ def __init__(self,
+ app_id,
+ host_port,
+ url_path,
+ kind,
+ throttle,
+ batch_size,
+ secure,
+ email,
+ passin,
+ dry_run=False):
+ """Initialize a RequestManager object.
+
+ Args:
+ app_id: String containing the application id for requests.
+ host_port: String containing the "host:port" pair; the port is optional.
+ url_path: partial URL (path) to post entity data to.
+ kind: Kind of the Entity records being posted.
+ throttle: A Throttle instance.
+ batch_size: The number of entities to transfer per request.
+ secure: Use SSL when communicating with server.
+      email: If not None, the username to log in with.
+      passin: If True, the password will be read from standard input.
+      dry_run: If True, do not execute any remote_api calls.
+ """
+ self.app_id = app_id
+ self.host_port = host_port
+ self.host = host_port.split(':')[0]
+ if url_path and url_path[0] != '/':
+ url_path = '/' + url_path
+ self.url_path = url_path
+ self.kind = kind
+ self.throttle = throttle
+ self.batch_size = batch_size
+ self.secure = secure
+ self.authenticated = False
+ self.auth_called = False
+ self.parallel_download = True
+ self.email = email
+ self.passin = passin
+ self.mapper = None
+ self.dry_run = dry_run
+
+ if self.dry_run:
+ logger.info('Running in dry run mode, skipping remote_api setup')
+ return
+
+ logger.debug('Configuring remote_api. url_path = %s, '
+ 'servername = %s' % (url_path, host_port))
+
+ def CookieHttpRpcServer(*args, **kwargs):
+ kwargs['save_cookies'] = True
+ kwargs['account_type'] = 'HOSTED_OR_GOOGLE'
+ return appengine_rpc.HttpRpcServer(*args, **kwargs)
+
+ remote_api_stub.ConfigureRemoteDatastore(
+ app_id,
+ url_path,
+ self.AuthFunction,
+ servername=host_port,
+ rpc_server_factory=CookieHttpRpcServer,
+ secure=self.secure)
+ remote_api_throttle.ThrottleRemoteDatastore(self.throttle)
+ logger.debug('Bulkloader using app_id: %s', os.environ['APPLICATION_ID'])
+
+ def Authenticate(self):
+ """Invoke authentication if necessary."""
+ logger.info('Connecting to %s%s', self.host_port, self.url_path)
+ if self.dry_run:
+ self.authenticated = True
+ return
+
+ remote_api_stub.MaybeInvokeAuthentication()
+ self.authenticated = True
+
+ def AuthFunction(self,
+ raw_input_fn=raw_input,
+ password_input_fn=getpass.getpass):
+ """Prompts the user for a username and password.
+
+ Caches the results the first time it is called and returns the
+ same result every subsequent time.
+
+ Args:
+ raw_input_fn: Used for dependency injection.
+ password_input_fn: Used for dependency injection.
+
+ Returns:
+ A pair of the username and password.
+ """
+ if self.email:
+ email = self.email
+ else:
+ print 'Please enter login credentials for %s' % (
+ self.host)
+ email = raw_input_fn('Email: ')
+
+ if email:
+ password_prompt = 'Password for %s: ' % email
+ if self.passin:
+ password = raw_input_fn(password_prompt)
+ else:
+ password = password_input_fn(password_prompt)
+ else:
+ password = None
+
+ self.auth_called = True
+ return (email, password)
+
+ def EncodeContent(self, rows, loader=None):
+ """Encodes row data to the wire format.
+
+ Args:
+ rows: A list of pairs of a line number and a list of column values.
+ loader: Used for dependency injection.
+
+ Returns:
+ A list of datastore.Entity instances.
+
+ Raises:
+ ConfigurationError: if no loader is defined for self.kind
+ """
+ if not loader:
+ try:
+ loader = Loader.RegisteredLoader(self.kind)
+ except KeyError:
+ logger.error('No Loader defined for kind %s.' % self.kind)
+ raise ConfigurationError('No Loader defined for kind %s.' % self.kind)
+ entities = []
+ for line_number, values in rows:
+ key = loader.generate_key(line_number, values)
+ if isinstance(key, datastore.Key):
+ parent = key.parent()
+ key = key.name()
+ else:
+ parent = None
+ entity = loader.create_entity(values, key_name=key, parent=parent)
+
+ def ToEntity(entity):
+ if isinstance(entity, db.Model):
+ return entity._populate_entity()
+ else:
+ return entity
+
+ if isinstance(entity, list):
+ entities.extend(map(ToEntity, entity))
+ elif entity:
+ entities.append(ToEntity(entity))
+
+ return entities
+
+ def PostEntities(self, entities):
+ """Posts Entity records to a remote endpoint over HTTP.
+
+ Args:
+ entities: A list of datastore entities.
+ """
+ if self.dry_run:
+ return
+ datastore.Put(entities)
+
+ def _QueryForPbs(self, query):
+ """Perform the given query and return a list of entity_pb's."""
+ try:
+ query_pb = query._ToPb(limit=self.batch_size)
+ result_pb = datastore_pb.QueryResult()
+ apiproxy_stub_map.MakeSyncCall('datastore_v3', 'RunQuery', query_pb,
+ result_pb)
+ next_pb = datastore_pb.NextRequest()
+ next_pb.set_count(self.batch_size)
+ next_pb.mutable_cursor().CopyFrom(result_pb.cursor())
+ result_pb = datastore_pb.QueryResult()
+ apiproxy_stub_map.MakeSyncCall('datastore_v3', 'Next', next_pb, result_pb)
+ return result_pb.result_list()
+ except apiproxy_errors.ApplicationError, e:
+ raise datastore._ToDatastoreError(e)
+
+ def GetEntities(self, key_range_item, key_factory=datastore.Key):
+ """Gets Entity records from a remote endpoint over HTTP.
+
+ Args:
+ key_range_item: Range of keys to get.
+ key_factory: Used for dependency injection.
+
+ Returns:
+ A DownloadResult instance.
+
+ Raises:
+ ConfigurationError: if no Exporter is defined for self.kind
+ """
+ keys = []
+ entities = []
+
+ if self.parallel_download:
+ query = key_range_item.key_range.make_directed_datastore_query(self.kind)
+ try:
+ results = self._QueryForPbs(query)
+ except datastore_errors.NeedIndexError:
+ logger.info('%s: No descending index on __key__, '
+ 'performing serial download', self.kind)
+ self.parallel_download = False
+
+ if not self.parallel_download:
+ key_range_item.key_range.direction = key_range_module.KeyRange.ASC
+ query = key_range_item.key_range.make_ascending_datastore_query(self.kind)
+ results = self._QueryForPbs(query)
+
+ size = len(results)
+
+ for entity in results:
+ key = key_factory()
+ key._Key__reference = entity.key()
+ entities.append(entity)
+ keys.append(key)
+
+ continued = (size == self.batch_size)
+ key_range_item.count = size
+
+ return DownloadResult(continued, key_range_item.key_range.direction,
+ keys, entities)
+
+ def GetMapper(self):
+ """Returns a mapper for the registered kind.
+
+ Returns:
+ A Mapper instance.
+
+ Raises:
+ ConfigurationError: if no Mapper is defined for self.kind
+ """
+ if not self.mapper:
+ try:
+ self.mapper = Mapper.RegisteredMapper(self.kind)
+ except KeyError:
+ logger.error('No Mapper defined for kind %s.' % self.kind)
+ raise ConfigurationError('No Mapper defined for kind %s.' % self.kind)
+ return self.mapper
+
+
+def InterruptibleSleep(sleep_time):
+ """Puts thread to sleep, checking this threads exit_flag twice a second.
+
+ Args:
+ sleep_time: Time to sleep.
+ """
+ slept = 0.0
+ epsilon = .0001
+ thread = threading.currentThread()
+ while slept < sleep_time - epsilon:
+ remaining = sleep_time - slept
+ this_sleep_time = min(remaining, 0.5)
+ time.sleep(this_sleep_time)
+ slept += this_sleep_time
+ if thread.exit_flag:
+ return
+
+
+class _ThreadBase(threading.Thread):
+ """Provide some basic features for the threads used in the uploader.
+
+ This abstract base class is used to provide some common features:
+
+ * Flag to ask thread to exit as soon as possible.
+ * Record exit/error status for the primary thread to pick up.
+ * Capture exceptions and record them for pickup.
+ * Some basic logging of thread start/stop.
+ * All threads are "daemon" threads.
+ * Friendly names for presenting to users.
+
+ Concrete sub-classes must implement PerformWork().
+
+ Either self.NAME should be set or GetFriendlyName() be overridden to
+ return a human-friendly name for this thread.
+
+  The run() method performs the work and logs start/exit messages.
+
+ self.exit_flag is intended to signal that this thread should exit
+ when it gets the chance. PerformWork() should check self.exit_flag
+ whenever it has the opportunity to exit gracefully.
+ """
+
+ def __init__(self):
+ threading.Thread.__init__(self)
+
+ self.setDaemon(True)
+
+ self.exit_flag = False
+ self.error = None
+ self.traceback = None
+
+ def run(self):
+ """Perform the work of the thread."""
+ logger.debug('[%s] %s: started', self.getName(), self.__class__.__name__)
+
+ try:
+ self.PerformWork()
+ except:
+ self.SetError()
+ logger.exception('[%s] %s:', self.getName(), self.__class__.__name__)
+
+ logger.debug('[%s] %s: exiting', self.getName(), self.__class__.__name__)
+
+ def SetError(self):
+ """Sets the error and traceback information for this thread.
+
+ This must be called from an exception handler.
+ """
+ if not self.error:
+ exc_info = sys.exc_info()
+ self.error = exc_info[1]
+ self.traceback = exc_info[2]
+
+ def PerformWork(self):
+ """Perform the thread-specific work."""
+ raise NotImplementedError()
+
+ def CheckError(self):
+ """If an error is present, then log it."""
+ if self.error:
+ logger.error('Error in %s: %s', self.GetFriendlyName(), self.error)
+ if self.traceback:
+ logger.debug(''.join(traceback.format_exception(self.error.__class__,
+ self.error,
+ self.traceback)))
+
+ def GetFriendlyName(self):
+ """Returns a human-friendly description of the thread."""
+ if hasattr(self, 'NAME'):
+ return self.NAME
+ return 'unknown thread'
+
+
+non_fatal_error_codes = set([errno.EAGAIN,
+ errno.ENETUNREACH,
+ errno.ENETRESET,
+ errno.ECONNRESET,
+ errno.ETIMEDOUT,
+ errno.EHOSTUNREACH])
+
+
+def IsURLErrorFatal(error):
+ """Returns False if the given URLError may be from a transient failure.
+
+ Args:
+ error: A urllib2.URLError instance.
+ """
+ assert isinstance(error, urllib2.URLError)
+ if not hasattr(error, 'reason'):
+ return True
+ if not isinstance(error.reason[0], int):
+ return True
+ return error.reason[0] not in non_fatal_error_codes
+
+
+class DataSourceThread(_ThreadBase):
+ """A thread which reads WorkItems and pushes them into queue.
+
+ This thread will read/consume WorkItems from a generator (produced by
+ the generator factory). These WorkItems will then be pushed into the
+ thread_pool. Note that reading will block if/when the thread_pool becomes
+ full. Information on content consumed from the generator will be pushed
+ into the progress_queue.
+ """
+
+ NAME = 'data source thread'
+
+ def __init__(self,
+ request_manager,
+ thread_pool,
+ progress_queue,
+ workitem_generator_factory,
+ progress_generator_factory):
+ """Initialize the DataSourceThread instance.
+
+ Args:
+ request_manager: A RequestManager instance.
+ thread_pool: An AdaptiveThreadPool instance.
+ progress_queue: A queue used for tracking progress information.
+ workitem_generator_factory: A factory that creates a WorkItem generator
+ progress_generator_factory: A factory that creates a generator which
+ produces prior progress status, or None if there is no prior status
+ to use.
+ """
+ _ThreadBase.__init__(self)
+
+ self.request_manager = request_manager
+ self.thread_pool = thread_pool
+ self.progress_queue = progress_queue
+ self.workitem_generator_factory = workitem_generator_factory
+ self.progress_generator_factory = progress_generator_factory
+ self.entity_count = 0
+
+ def PerformWork(self):
+ """Performs the work of a DataSourceThread."""
+ if self.progress_generator_factory:
+ progress_gen = self.progress_generator_factory()
+ else:
+ progress_gen = None
+
+ content_gen = self.workitem_generator_factory(self.request_manager,
+ self.progress_queue,
+ progress_gen)
+
+ self.xfer_count = 0
+ self.read_count = 0
+ self.read_all = False
+
+ for item in content_gen.Batches():
+ item.MarkAsRead()
+
+ while not self.exit_flag:
+ try:
+ self.thread_pool.SubmitItem(item, block=True, timeout=1.0)
+ self.entity_count += item.count
+ break
+ except Queue.Full:
+ pass
+
+ if self.exit_flag:
+ break
+
+ if not self.exit_flag:
+ self.read_all = True
+ self.read_count = content_gen.row_count
+ self.xfer_count = content_gen.xfer_count
+
+
+
+def _RunningInThread(thread):
+ """Return True if we are running within the specified thread."""
+ return threading.currentThread().getName() == thread.getName()
+
+
+class _Database(object):
+ """Base class for database connections in this module.
+
+ The table is created by a primary thread (the python main thread)
+ but all future lookups and updates are performed by a secondary
+ thread.
+ """
+
+ SIGNATURE_TABLE_NAME = 'bulkloader_database_signature'
+
+ def __init__(self,
+ db_filename,
+ create_table,
+ signature,
+ index=None,
+ commit_periodicity=100):
+ """Initialize the _Database instance.
+
+ Args:
+ db_filename: The sqlite3 file to use for the database.
+ create_table: A string containing the SQL table creation command.
+ signature: A string identifying the important invocation options,
+ used to make sure we are not using an old database.
+ index: An optional string to create an index for the database.
+ commit_periodicity: Number of operations between database commits.
+ """
+ self.db_filename = db_filename
+
+ logger.info('Opening database: %s', db_filename)
+ self.primary_conn = sqlite3.connect(db_filename, isolation_level=None)
+ self.primary_thread = threading.currentThread()
+
+ self.secondary_conn = None
+ self.secondary_thread = None
+
+ self.operation_count = 0
+ self.commit_periodicity = commit_periodicity
+
+ try:
+ self.primary_conn.execute(create_table)
+ except sqlite3.OperationalError, e:
+ if 'already exists' not in e.message:
+ raise
+
+ if index:
+ try:
+ self.primary_conn.execute(index)
+ except sqlite3.OperationalError, e:
+ if 'already exists' not in e.message:
+ raise
+
+ self.existing_table = False
+ signature_cursor = self.primary_conn.cursor()
+ create_signature = """
+ create table %s (
+ value TEXT not null)
+ """ % _Database.SIGNATURE_TABLE_NAME
+ try:
+ self.primary_conn.execute(create_signature)
+ self.primary_conn.cursor().execute(
+ 'insert into %s (value) values (?)' % _Database.SIGNATURE_TABLE_NAME,
+ (signature,))
+ except sqlite3.OperationalError, e:
+ if 'already exists' not in e.message:
+ logger.exception('Exception creating table:')
+ raise
+ else:
+ self.existing_table = True
+ signature_cursor.execute(
+ 'select * from %s' % _Database.SIGNATURE_TABLE_NAME)
+ (result,) = signature_cursor.fetchone()
+ if result and result != signature:
+ logger.error('Database signature mismatch:\n\n'
+ 'Found:\n'
+ '%s\n\n'
+ 'Expecting:\n'
+ '%s\n',
+ result, signature)
+ raise ResumeError('Database signature mismatch: %s != %s' % (
+ signature, result))
+
+ def ThreadComplete(self):
+ """Finalize any operations the secondary thread has performed.
+
+ The database aggregates lots of operations into a single commit, and
+ this method is used to commit any pending operations as the thread
+ is about to shut down.
+ """
+ if self.secondary_conn:
+ self._MaybeCommit(force_commit=True)
+
+ def _MaybeCommit(self, force_commit=False):
+ """Periodically commit changes into the SQLite database.
+
+ Committing every operation is quite expensive, and slows down the
+ operation of the script. Thus, we only commit after every N operations,
+ as determined by the self.commit_periodicity value. Optionally, the
+ caller can force a commit.
+
+ Args:
+ force_commit: Pass True in order for a commit to occur regardless
+ of the current operation count.
+ """
+ self.operation_count += 1
+ if force_commit or (self.operation_count % self.commit_periodicity) == 0:
+ self.secondary_conn.commit()
+
+ def _OpenSecondaryConnection(self):
+ """Possibly open a database connection for the secondary thread.
+
+ If the connection is not open (for the calling thread, which is assumed
+    to be the unique secondary thread), then open it. We also open a couple of
+ cursors for later use (and reuse).
+ """
+ if self.secondary_conn:
+ return
+
+ assert not _RunningInThread(self.primary_thread)
+
+ self.secondary_thread = threading.currentThread()
+
+ self.secondary_conn = sqlite3.connect(self.db_filename)
+
+ self.insert_cursor = self.secondary_conn.cursor()
+ self.update_cursor = self.secondary_conn.cursor()
+
+
+zero_matcher = re.compile(r'\x00')
+
+zero_one_matcher = re.compile(r'\x00\x01')
+
+
+def KeyStr(key):
+ """Returns a string to represent a key, preserving ordering.
+
+ Unlike datastore.Key.__str__(), we have the property:
+
+ key1 < key2 ==> KeyStr(key1) < KeyStr(key2)
+
+ The key string is constructed from the key path as follows:
+  (1) Strings are prepended with ':' and numeric ids are padded to
+ 20 digits.
+ (2) Any null characters (u'\0') present are replaced with u'\0\1'
+ (3) The sequence u'\0\0' is used to separate each component of the path.
+
+ (1) assures that names and ids compare properly, while (2) and (3) enforce
+ the part-by-part comparison of pieces of the path.
+
+ Args:
+ key: A datastore.Key instance.
+
+ Returns:
+ A string representation of the key, which preserves ordering.
+ """
+ assert isinstance(key, datastore.Key)
+ path = key.to_path()
+
+ out_path = []
+ for part in path:
+ if isinstance(part, (int, long)):
+ part = '%020d' % part
+ else:
+ part = ':%s' % part
+
+ out_path.append(zero_matcher.sub(u'\0\1', part))
+
+ out_str = u'\0\0'.join(out_path)
+
+ return out_str
+
+
+def StrKey(key_str):
+ """The inverse of the KeyStr function.
+
+ Args:
+ key_str: A string in the range of KeyStr.
+
+ Returns:
+ A datastore.Key instance k, such that KeyStr(k) == key_str.
+ """
+ parts = key_str.split(u'\0\0')
+ for i in xrange(len(parts)):
+ if parts[i][0] == ':':
+ part = parts[i][1:]
+ part = zero_one_matcher.sub(u'\0', part)
+ parts[i] = part
+ else:
+ parts[i] = int(parts[i])
+ return datastore.Key.from_path(*parts)
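+
+# Illustrative example: for a key built as
+#   datastore.Key.from_path('Parent', 42, 'Child', u'bob')
+# the path components are encoded as
+#   [':Parent', '00000000000000000042', ':Child', ':bob']
+# and joined with u'\0\0', so keys compare part-by-part; StrKey() inverts the
+# encoding, giving back the original key.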
+
+
+class ResultDatabase(_Database):
+ """Persistently record all the entities downloaded during an export.
+
+ The entities are held in the database by their unique datastore key
+ in order to avoid duplication if an export is restarted.
+ """
+
+ def __init__(self, db_filename, signature, commit_periodicity=1):
+ """Initialize a ResultDatabase object.
+
+ Args:
+ db_filename: The name of the SQLite database to use.
+ signature: A string identifying the important invocation options,
+ used to make sure we are not using an old database.
+ commit_periodicity: How many operations to perform between commits.
+ """
+ self.complete = False
+ create_table = ('create table result (\n'
+ 'id BLOB primary key,\n'
+ 'value BLOB not null)')
+
+ _Database.__init__(self,
+ db_filename,
+ create_table,
+ signature,
+ commit_periodicity=commit_periodicity)
+ if self.existing_table:
+ cursor = self.primary_conn.cursor()
+ cursor.execute('select count(*) from result')
+ self.existing_count = int(cursor.fetchone()[0])
+ else:
+ self.existing_count = 0
+ self.count = self.existing_count
+
+ def _StoreEntity(self, entity_id, entity):
+ """Store an entity in the result database.
+
+ Args:
+ entity_id: A datastore.Key for the entity.
+ entity: The entity to store.
+
+ Returns:
+      True if this entity was not already present in the result database.
+ """
+
+ assert _RunningInThread(self.secondary_thread)
+ assert isinstance(entity_id, datastore.Key), (
+ 'expected a datastore.Key, got a %s' % entity_id.__class__.__name__)
+
+ key_str = buffer(KeyStr(entity_id).encode('utf-8'))
+ self.insert_cursor.execute(
+ 'select count(*) from result where id = ?', (key_str,))
+
+ already_present = self.insert_cursor.fetchone()[0]
+ result = True
+ if already_present:
+ result = False
+ self.insert_cursor.execute('delete from result where id = ?',
+ (key_str,))
+ else:
+ self.count += 1
+ value = entity.Encode()
+ self.insert_cursor.execute(
+ 'insert into result (id, value) values (?, ?)',
+ (key_str, buffer(value)))
+ return result
+
+ def StoreEntities(self, keys, entities):
+ """Store a group of entities in the result database.
+
+ Args:
+ keys: A list of entity keys.
+ entities: A list of entities.
+
+ Returns:
+ The number of new entities stored in the result database.
+ """
+ self._OpenSecondaryConnection()
+ t = time.time()
+ count = 0
+ for entity_id, entity in zip(keys,
+ entities):
+ if self._StoreEntity(entity_id, entity):
+ count += 1
+ logger.debug('%s insert: delta=%.3f',
+ self.db_filename,
+ time.time() - t)
+ logger.debug('Entities transferred total: %s', self.count)
+ self._MaybeCommit()
+ return count
+
+ def ResultsComplete(self):
+ """Marks the result database as containing complete results."""
+ self.complete = True
+
+ def AllEntities(self):
+ """Yields all pairs of (id, value) from the result table."""
+ conn = sqlite3.connect(self.db_filename, isolation_level=None)
+ cursor = conn.cursor()
+
+ cursor.execute(
+ 'select id, value from result order by id')
+
+ for unused_entity_id, entity in cursor:
+ entity_proto = entity_pb.EntityProto(contents=entity)
+ yield datastore.Entity._FromPb(entity_proto)
+
+
+class _ProgressDatabase(_Database):
+ """Persistently record all progress information during an upload.
+
+ This class wraps a very simple SQLite database which records each of
+ the relevant details from a chunk of work. If the loader is
+ resumed, then data is replayed out of the database.
+ """
+
+ def __init__(self,
+ db_filename,
+ sql_type,
+ py_type,
+ signature,
+ commit_periodicity=100):
+ """Initialize the ProgressDatabase instance.
+
+ Args:
+ db_filename: The name of the SQLite database to use.
+ sql_type: A string of the SQL type to use for entity keys.
+ py_type: The python type of entity keys.
+ signature: A string identifying the important invocation options,
+ used to make sure we are not using an old database.
+ commit_periodicity: How many operations to perform between commits.
+ """
+ self.prior_key_end = None
+
+ create_table = ('create table progress (\n'
+ 'id integer primary key autoincrement,\n'
+ 'state integer not null,\n'
+ 'key_start %s,\n'
+ 'key_end %s)'
+ % (sql_type, sql_type))
+ self.py_type = py_type
+
+ index = 'create index i_state on progress (state)'
+ _Database.__init__(self,
+ db_filename,
+ create_table,
+ signature,
+ index=index,
+ commit_periodicity=commit_periodicity)
+
+ def UseProgressData(self):
+ """Returns True if the database has progress information.
+
+ Note there are two basic cases for progress information:
+ 1) All saved records indicate a successful upload. In this case, we
+ need to skip everything transmitted so far and then send the rest.
+ 2) Some records for incomplete transfer are present. These need to be
+ sent again, and then we resume sending after all the successful
+ data.
+
+ Returns:
+      True if the database has progress information.
+
+ Raises:
+ ResumeError: if there is an error retrieving rows from the database.
+ """
+ assert _RunningInThread(self.primary_thread)
+
+ cursor = self.primary_conn.cursor()
+ cursor.execute('select count(*) from progress')
+ row = cursor.fetchone()
+ if row is None:
+ raise ResumeError('Cannot retrieve progress information from database.')
+
+ return row[0] != 0
+
+ def StoreKeys(self, key_start, key_end):
+ """Record a new progress record, returning a key for later updates.
+
+ The specified progress information will be persisted into the database.
+ A unique key will be returned that identifies this progress state. The
+ key is later used to (quickly) update this record.
+
+ For the progress resumption to proceed properly, calls to StoreKeys
+ MUST specify monotonically increasing key ranges. This will result in
+ a database whereby the ID, KEY_START, and KEY_END rows are all
+ increasing (rather than having ranges out of order).
+
+ NOTE: the above precondition is NOT tested by this method (since it
+ would imply an additional table read or two on each invocation).
+
+ Args:
+ key_start: The starting key of the WorkItem (inclusive)
+ key_end: The end key of the WorkItem (inclusive)
+
+ Returns:
+ A string to later be used as a unique key to update this state.
+ """
+ self._OpenSecondaryConnection()
+
+ assert _RunningInThread(self.secondary_thread)
+ assert (not key_start) or isinstance(key_start, self.py_type), (
+ '%s is a %s, %s expected %s' % (key_start,
+ key_start.__class__,
+ self.__class__.__name__,
+ self.py_type))
+ assert (not key_end) or isinstance(key_end, self.py_type), (
+ '%s is a %s, %s expected %s' % (key_end,
+ key_end.__class__,
+ self.__class__.__name__,
+ self.py_type))
+ assert KeyLEQ(key_start, key_end), '%s not less than %s' % (
+ repr(key_start), repr(key_end))
+
+ self.insert_cursor.execute(
+ 'insert into progress (state, key_start, key_end) values (?, ?, ?)',
+ (STATE_READ, unicode(key_start), unicode(key_end)))
+
+ progress_key = self.insert_cursor.lastrowid
+
+ self._MaybeCommit()
+
+ return progress_key
+
+ def UpdateState(self, key, new_state):
+ """Update a specified progress record with new information.
+
+ Args:
+ key: The key for this progress record, returned from StoreKeys
+ new_state: The new state to associate with this progress record.
+ """
+ self._OpenSecondaryConnection()
+
+ assert _RunningInThread(self.secondary_thread)
+ assert isinstance(new_state, int)
+
+ self.update_cursor.execute('update progress set state=? where id=?',
+ (new_state, key))
+
+ self._MaybeCommit()
+
+ def DeleteKey(self, progress_key):
+ """Delete the entities with the given key from the result database."""
+ self._OpenSecondaryConnection()
+
+ assert _RunningInThread(self.secondary_thread)
+
+ t = time.time()
+ self.insert_cursor.execute(
+ 'delete from progress where rowid = ?', (progress_key,))
+
+ logger.debug('delete: delta=%.3f', time.time() - t)
+
+ self._MaybeCommit()
+
+ def GetProgressStatusGenerator(self):
+ """Get a generator which yields progress information.
+
+ The returned generator will yield a series of 4-tuples that specify
+ progress information about a prior run of the uploader. The 4-tuples
+ have the following values:
+
+ progress_key: The unique key to later update this record with new
+ progress information.
+ state: The last state saved for this progress record.
+ key_start: The starting key of the items for uploading (inclusive).
+ key_end: The ending key of the items for uploading (inclusive).
+
+ After all incompletely-transferred records are provided, then one
+ more 4-tuple will be generated:
+
+ None
+ DATA_CONSUMED_TO_HERE: A unique string value indicating this record
+ is being provided.
+ None
+ key_end: An integer value specifying the last data source key that
+ was handled by the previous run of the uploader.
+
+ The caller should begin uploading records which occur after key_end.
+
+ Yields:
+ Four-tuples of (progress_key, state, key_start, key_end)
+ """
+ conn = sqlite3.connect(self.db_filename, isolation_level=None)
+ cursor = conn.cursor()
+
+ cursor.execute('select max(key_end) from progress')
+
+ result = cursor.fetchone()
+    if result is not None and result[0] is not None:
+ key_end = result[0]
+ else:
+ logger.debug('No rows in progress database.')
+ return
+
+ self.prior_key_end = key_end
+
+ cursor.execute(
+ 'select id, state, key_start, key_end from progress'
+ ' where state != ?'
+ ' order by id',
+ (STATE_SENT,))
+
+ rows = cursor.fetchall()
+
+ for row in rows:
+ if row is None:
+ break
+ progress_key, state, key_start, key_end = row
+
+ yield progress_key, state, key_start, key_end
+
+ yield None, DATA_CONSUMED_TO_HERE, None, key_end
+
+
+def ProgressDatabase(db_filename, signature):
+ """Returns a database to store upload progress information."""
+ return _ProgressDatabase(db_filename, 'INTEGER', int, signature)
+
+
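+# ProgressDatabase (above) tracks upload progress using integer key ranges,
+# while ExportProgressDatabase below tracks download progress using encoded
+# datastore keys stored as TEXT, committing after every operation.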
+class ExportProgressDatabase(_ProgressDatabase):
+ """A database to store download progress information."""
+
+ def __init__(self, db_filename, signature):
+ """Initialize an ExportProgressDatabase."""
+ _ProgressDatabase.__init__(self,
+ db_filename,
+ 'TEXT',
+ datastore.Key,
+ signature,
+ commit_periodicity=1)
+
+ def UseProgressData(self):
+ """Check if the progress database contains progress data.
+
+ Returns:
+ True: if the database contains progress data.
+ """
+ return self.existing_table
+
+
+class StubProgressDatabase(object):
+ """A stub implementation of ProgressDatabase which does nothing."""
+
+ def UseProgressData(self):
+ """Whether the stub database has progress information (it doesn't)."""
+ return False
+
+ def StoreKeys(self, unused_key_start, unused_key_end):
+ """Pretend to store a key in the stub database."""
+ return 'fake-key'
+
+ def UpdateState(self, unused_key, unused_new_state):
+ """Pretend to update the state of a progress item."""
+ pass
+
+ def ThreadComplete(self):
+ """Finalize operations on the stub database (i.e. do nothing)."""
+ pass
+
+
+class _ProgressThreadBase(_ThreadBase):
+ """A thread which records progress information for the upload process.
+
+ The progress information is stored into the provided progress database.
+ This class is not responsible for replaying a prior run's progress
+ information out of the database. Separate mechanisms must be used to
+ resume a prior upload attempt.
+ """
+
+ NAME = 'progress tracking thread'
+
+ def __init__(self, progress_queue, progress_db):
+ """Initialize the ProgressTrackerThread instance.
+
+ Args:
+ progress_queue: A Queue used for tracking progress information.
+ progress_db: The database for tracking progress information; should
+ be an instance of ProgressDatabase.
+ """
+ _ThreadBase.__init__(self)
+
+ self.progress_queue = progress_queue
+ self.db = progress_db
+ self.entities_transferred = 0
+
+ def EntitiesTransferred(self):
+ """Return the total number of unique entities transferred."""
+ return self.entities_transferred
+
+ def UpdateProgress(self, item):
+ """Updates the progress information for the given item.
+
+ Args:
+ item: A work item whose new state will be recorded
+ """
+ raise NotImplementedError()
+
+ def WorkFinished(self):
+ """Performs final actions after the entity transfer is complete."""
+ raise NotImplementedError()
+
+ def PerformWork(self):
+ """Performs the work of a ProgressTrackerThread."""
+ while not self.exit_flag:
+ try:
+ item = self.progress_queue.get(block=True, timeout=1.0)
+ except Queue.Empty:
+ continue
+ if item == _THREAD_SHOULD_EXIT:
+ break
+
+ if item.state == STATE_READ and item.progress_key is None:
+ item.progress_key = self.db.StoreKeys(item.key_start, item.key_end)
+ else:
+ assert item.progress_key is not None
+ self.UpdateProgress(item)
+
+ item.progress_event.set()
+
+ self.progress_queue.task_done()
+
+ self.db.ThreadComplete()
+
+
+class ProgressTrackerThread(_ProgressThreadBase):
+ """A thread which records progress information for the upload process.
+
+ The progress information is stored into the provided progress database.
+ This class is not responsible for replaying a prior run's progress
+ information out of the database. Separate mechanisms must be used to
+ resume a prior upload attempt.
+ """
+ NAME = 'progress tracking thread'
+
+ def __init__(self, progress_queue, progress_db):
+ """Initialize the ProgressTrackerThread instance.
+
+ Args:
+ progress_queue: A Queue used for tracking progress information.
+ progress_db: The database for tracking progress information; should
+ be an instance of ProgressDatabase.
+ """
+ _ProgressThreadBase.__init__(self, progress_queue, progress_db)
+
+ def UpdateProgress(self, item):
+ """Update the state of the given WorkItem.
+
+ Args:
+ item: A WorkItem instance.
+ """
+ self.db.UpdateState(item.progress_key, item.state)
+ if item.state == STATE_SENT:
+ self.entities_transferred += item.count
+
+ def WorkFinished(self):
+ """Performs final actions after the entity transfer is complete."""
+ pass
+
+
+class ExportProgressThread(_ProgressThreadBase):
+ """A thread to record progress information and write record data for exports.
+
+ The progress information is stored into a provided progress database.
+ Exported results are stored in the result database and dumped to an output
+ file at the end of the download.
+ """
+
+ def __init__(self, kind, progress_queue, progress_db, result_db):
+ """Initialize the ExportProgressThread instance.
+
+ Args:
+ kind: The kind of entities being stored in the database.
+ progress_queue: A Queue used for tracking progress information.
+ progress_db: The database for tracking progress information; should
+ be an instance of ProgressDatabase.
+ result_db: The database for holding exported entities; should be an
+ instance of ResultDatabase.
+ """
+ _ProgressThreadBase.__init__(self, progress_queue, progress_db)
+
+ self.kind = kind
+ self.existing_count = result_db.existing_count
+ self.result_db = result_db
+
+ def EntitiesTransferred(self):
+ """Return the total number of unique entities transferred."""
+ return self.result_db.count
+
+ def WorkFinished(self):
+ """Write the contents of the result database."""
+ exporter = Exporter.RegisteredExporter(self.kind)
+ exporter.output_entities(self.result_db.AllEntities())
+
+ def UpdateProgress(self, item):
+ """Update the state of the given KeyRangeItem.
+
+ Args:
+      item: A KeyRangeItem instance.
+ """
+ if item.state == STATE_GOT:
+ count = self.result_db.StoreEntities(item.download_result.keys,
+ item.download_result.entities)
+ self.db.DeleteKey(item.progress_key)
+ self.entities_transferred += count
+ else:
+ self.db.UpdateState(item.progress_key, item.state)
+
+
+class MapperProgressThread(_ProgressThreadBase):
+ """A thread to record progress information for maps over the datastore."""
+
+ def __init__(self, kind, progress_queue, progress_db):
+ """Initialize the MapperProgressThread instance.
+
+ Args:
+ kind: The kind of entities being stored in the database.
+ progress_queue: A Queue used for tracking progress information.
+ progress_db: The database for tracking progress information; should
+ be an instance of ProgressDatabase.
+ """
+ _ProgressThreadBase.__init__(self, progress_queue, progress_db)
+
+ self.kind = kind
+ self.mapper = Mapper.RegisteredMapper(self.kind)
+
+ def EntitiesTransferred(self):
+ """Return the total number of unique entities transferred."""
+ return self.entities_transferred
+
+ def WorkFinished(self):
+ """Perform actions after map is complete."""
+ pass
+
+ def UpdateProgress(self, item):
+ """Update the state of the given KeyRangeItem.
+
+ Args:
+      item: A KeyRangeItem instance.
+ """
+ if item.state == STATE_GOT:
+ self.entities_transferred += item.count
+ self.db.DeleteKey(item.progress_key)
+ else:
+ self.db.UpdateState(item.progress_key, item.state)
+
+
+def ParseKey(key_string):
+ """Turn a key stored in the database into a Key or None.
+
+ Args:
+ key_string: The string representation of a Key.
+
+ Returns:
+ A datastore.Key instance or None
+ """
+ if not key_string:
+ return None
+ if key_string == 'None':
+ return None
+ return datastore.Key(encoded=key_string)
+
+
+def Validate(value, typ):
+ """Checks that value is non-empty and of the right type.
+
+ Args:
+ value: any value
+ typ: a type or tuple of types
+
+ Raises:
+ ValueError: if value is None or empty.
+ TypeError: if it's not the given type.
+ """
+ if not value:
+ raise ValueError('Value should not be empty; received %s.' % value)
+ elif not isinstance(value, typ):
+ raise TypeError('Expected a %s, but received %s (a %s).' %
+ (typ, value, value.__class__))
+
+
+def CheckFile(filename):
+ """Check that the given file exists and can be opened for reading.
+
+ Args:
+ filename: The name of the file.
+
+ Raises:
+ FileNotFoundError: if the given filename is not found
+ FileNotReadableError: if the given filename is not readable.
+ """
+ if not os.path.exists(filename):
+ raise FileNotFoundError('%s: file not found' % filename)
+ elif not os.access(filename, os.R_OK):
+ raise FileNotReadableError('%s: file not readable' % filename)
+
+
+class Loader(object):
+ """A base class for creating datastore entities from input data.
+
+ To add a handler for bulk loading a new entity kind into your datastore,
+ write a subclass of this class that calls Loader.__init__ from your
+ class's __init__.
+
+ If you need to run extra code to convert entities from the input
+ data, create new properties, or otherwise modify the entities before
+ they're inserted, override handle_entity.
+
+ See the create_entity method for the creation of entities from the
+ (parsed) input data.
+ """
+
+ __loaders = {}
+ kind = None
+ __properties = None
+
+ def __init__(self, kind, properties):
+ """Constructor.
+
+ Populates this Loader's kind and properties map.
+
+ Args:
+ kind: a string containing the entity kind that this loader handles
+
+ properties: list of (name, converter) tuples.
+
+ This is used to automatically convert the input columns into
+ properties. The converter should be a function that takes one
+ argument, a string value from the input file, and returns a
+ correctly typed property value that should be inserted. The
+ tuples in this list should match the columns in your input file,
+ in order.
+
+ For example:
+ [('name', str),
+ ('id_number', int),
+ ('email', datastore_types.Email),
+ ('user', users.User),
+ ('birthdate', lambda x: datetime.datetime.fromtimestamp(float(x))),
+ ('description', datastore_types.Text),
+ ]
+ """
+ Validate(kind, (basestring, tuple))
+ self.kind = kind
+ self.__openfile = open
+ self.__create_csv_reader = csv.reader
+
+ GetImplementationClass(kind)
+
+ Validate(properties, list)
+ for name, fn in properties:
+ Validate(name, basestring)
+ assert callable(fn), (
+ 'Conversion function %s for property %s is not callable.' % (fn, name))
+
+ self.__properties = properties
+
+ @staticmethod
+ def RegisterLoader(loader):
+ """Register loader and the Loader instance for its kind.
+
+ Args:
+ loader: A Loader instance.
+ """
+ Loader.__loaders[loader.kind] = loader
+
+ def alias_old_names(self):
+ """Aliases method names so that Loaders defined with old names work."""
+ aliases = (
+ ('CreateEntity', 'create_entity'),
+ ('HandleEntity', 'handle_entity'),
+ ('GenerateKey', 'generate_key'),
+ )
+ for old_name, new_name in aliases:
+ setattr(Loader, old_name, getattr(Loader, new_name))
+ if hasattr(self.__class__, old_name) and not (
+ getattr(self.__class__, old_name).im_func ==
+ getattr(Loader, new_name).im_func):
+ if hasattr(self.__class__, new_name) and not (
+ getattr(self.__class__, new_name).im_func ==
+ getattr(Loader, new_name).im_func):
+ raise NameClashError(old_name, new_name, self.__class__)
+ setattr(self, new_name, getattr(self, old_name))
+
+ def create_entity(self, values, key_name=None, parent=None):
+ """Creates a entity from a list of property values.
+
+ Args:
+ values: list/tuple of str
+ key_name: if provided, the name for the (single) resulting entity
+ parent: A datastore.Key instance for the parent, or None
+
+ Returns:
+ list of db.Model
+
+ The returned entities are populated with the property values from the
+ argument, converted to native types using the properties map given in
+ the constructor, and passed through handle_entity. They're ready to be
+ inserted.
+
+ Raises:
+ AssertionError: if the number of values doesn't match the number
+ of properties in the properties map.
+ ValueError: if any element of values is None or empty.
+ TypeError: if values is not a list or tuple.
+ """
+ Validate(values, (list, tuple))
+ assert len(values) == len(self.__properties), (
+ 'Expected %d columns, found %d.' %
+ (len(self.__properties), len(values)))
+
+ model_class = GetImplementationClass(self.kind)
+
+ properties = {
+ 'key_name': key_name,
+ 'parent': parent,
+ }
+ for (name, converter), val in zip(self.__properties, values):
+ if converter is bool and val.lower() in ('0', 'false', 'no'):
+ val = False
+ properties[name] = converter(val)
+
+ entity = model_class(**properties)
+ entities = self.handle_entity(entity)
+
+ if entities:
+ if not isinstance(entities, (list, tuple)):
+ entities = [entities]
+
+ for entity in entities:
+ if not isinstance(entity, db.Model):
+ raise TypeError('Expected a db.Model, received %s (a %s).' %
+ (entity, entity.__class__))
+
+ return entities
+
+ def generate_key(self, i, values):
+ """Generates a key_name to be used in creating the underlying object.
+
+ The default implementation returns None.
+
+ This method can be overridden to control the key generation for
+ uploaded entities. The value returned should be None (to use a
+ server generated numeric key), or a string which neither starts
+ with a digit nor has the form __*__ (see
+ http://code.google.com/appengine/docs/python/datastore/keysandentitygroups.html),
+ or a datastore.Key instance.
+
+ If you generate your own string keys, keep in mind:
+
+ 1. The key name for each entity must be unique.
+ 2. If an entity of the same kind and key already exists in the
+ datastore, it will be overwritten.
+
+ Args:
+      i: Number corresponding to this object (assuming this method is
+        called in a loop, this is the current count).
+      values: list/tuple of str.
+
+    Returns:
+      A string, datastore.Key, or None to be used as the key_name for an
+      entity.
+ """
+ return None
+
+ def handle_entity(self, entity):
+ """Subclasses can override this to add custom entity conversion code.
+
+ This is called for each entity, after its properties are populated
+ from the input but before it is stored. Subclasses can override
+ this to add custom entity handling code.
+
+ The entity to be inserted should be returned. If multiple entities
+ should be inserted, return a list of entities. If no entities
+ should be inserted, return None or [].
+
+ Args:
+ entity: db.Model
+
+ Returns:
+ db.Model or list of db.Model
+ """
+ return entity
+
+ def initialize(self, filename, loader_opts):
+ """Performs initialization and validation of the input file.
+
+ This implementation checks that the input file exists and can be
+ opened for reading.
+
+ Args:
+ filename: The string given as the --filename flag argument.
+ loader_opts: The string given as the --loader_opts flag argument.
+ """
+ CheckFile(filename)
+
+ def finalize(self):
+ """Performs finalization actions after the upload completes."""
+ pass
+
+ def generate_records(self, filename):
+ """Subclasses can override this to add custom data input code.
+
+ This method must yield fixed-length lists of strings.
+
+ The default implementation uses csv.reader to read CSV rows
+ from filename.
+
+ Args:
+ filename: The string input for the --filename option.
+
+ Yields:
+ Lists of strings.
+ """
+ csv_generator = CSVGenerator(filename, openfile=self.__openfile,
+ create_csv_reader=self.__create_csv_reader
+ ).Records()
+ return csv_generator
+
+ @staticmethod
+ def RegisteredLoaders():
+ """Returns a dict of the Loader instances that have been created."""
+ return dict(Loader.__loaders)
+
+ @staticmethod
+ def RegisteredLoader(kind):
+ """Returns the loader instance for the given kind if it exists."""
+ return Loader.__loaders[kind]
+
+
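+# A minimal config-file sketch for the Loader API above (illustrative only:
+# the 'Album' kind, its properties, and the 'models' module are assumptions,
+# not part of this module). The kind's db.Model class must be importable so
+# that GetImplementationClass can resolve it; LoadConfig (below) imports the
+# config file and registers every class listed in 'loaders'.
+#
+#   from google.appengine.tools.bulkloader import Loader
+#   import models  # hypothetical module defining the 'Album' db.Model
+#
+#   class AlbumLoader(Loader):
+#     def __init__(self):
+#       Loader.__init__(self, 'Album',
+#                       [('title', str),
+#                        ('release_year', int)])
+#
+#   loaders = [AlbumLoader]
+
+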
+class RestoreThread(_ThreadBase):
+ """A thread to read saved entity_pbs from sqlite3."""
+ NAME = 'RestoreThread'
+ _ENTITIES_DONE = 'Entities Done'
+
+ def __init__(self, queue, filename):
+ _ThreadBase.__init__(self)
+ self.queue = queue
+ self.filename = filename
+
+ def PerformWork(self):
+ db_conn = sqlite3.connect(self.filename)
+ cursor = db_conn.cursor()
+ cursor.execute('select id, value from result')
+ for entity_id, value in cursor:
+ self.queue.put([entity_id, value], block=True)
+ self.queue.put(RestoreThread._ENTITIES_DONE, block=True)
+
+
+class RestoreLoader(Loader):
+ """A Loader which imports protobuffers from a file."""
+
+ def __init__(self, kind):
+ self.kind = kind
+
+ def initialize(self, filename, loader_opts):
+ CheckFile(filename)
+ self.queue = Queue.Queue(1000)
+ restore_thread = RestoreThread(self.queue, filename)
+ restore_thread.start()
+
+ def generate_records(self, filename):
+ while True:
+ record = self.queue.get(block=True)
+ if id(record) == id(RestoreThread._ENTITIES_DONE):
+ break
+ yield record
+
+ def create_entity(self, values, key_name=None, parent=None):
+ key = StrKey(unicode(values[0], 'utf-8'))
+ entity_proto = entity_pb.EntityProto(contents=str(values[1]))
+ entity_proto.mutable_key().CopyFrom(key._Key__reference)
+ return datastore.Entity._FromPb(entity_proto)
+
+
+class Exporter(object):
+ """A base class for serializing datastore entities.
+
+ To add a handler for exporting an entity kind from your datastore,
+ write a subclass of this class that calls Exporter.__init__ from your
+ class's __init__.
+
+ If you need to run extra code to convert entities from the input
+ data, create new properties, or otherwise modify the entities before
+ they're inserted, override handle_entity.
+
+ See the output_entities method for the writing of data from entities.
+ """
+
+ __exporters = {}
+ kind = None
+ __properties = None
+
+ def __init__(self, kind, properties):
+ """Constructor.
+
+    Populates this Exporter's kind and properties map.
+
+ Args:
+ kind: a string containing the entity kind that this exporter handles
+
+ properties: list of (name, converter, default) tuples.
+
+ This is used to automatically convert the entities to strings.
+ The converter should be a function that takes one argument, a property
+ value of the appropriate type, and returns a str or unicode. The default
+ is a string to be used if the property is not present, or None to fail
+ with an error if the property is missing.
+
+ For example:
+ [('name', str, None),
+ ('id_number', str, None),
+ ('email', str, ''),
+ ('user', str, None),
+ ('birthdate',
+ lambda x: str(datetime.datetime.fromtimestamp(float(x))),
+ None),
+ ('description', str, ''),
+ ]
+ """
+ Validate(kind, basestring)
+ self.kind = kind
+
+ GetImplementationClass(kind)
+
+ Validate(properties, list)
+ for name, fn, default in properties:
+ Validate(name, basestring)
+ assert callable(fn), (
+ 'Conversion function %s for property %s is not callable.' % (
+ fn, name))
+ if default:
+ Validate(default, basestring)
+
+ self.__properties = properties
+
+ @staticmethod
+ def RegisterExporter(exporter):
+ """Register exporter and the Exporter instance for its kind.
+
+ Args:
+      exporter: An Exporter instance.
+ """
+ Exporter.__exporters[exporter.kind] = exporter
+
+ def __ExtractProperties(self, entity):
+ """Converts an entity into a list of string values.
+
+ Args:
+ entity: An entity to extract the properties from.
+
+ Returns:
+ A list of the properties of the entity.
+
+ Raises:
+ MissingPropertyError: if an expected field on the entity is missing.
+ """
+ encoding = []
+ for name, fn, default in self.__properties:
+ try:
+ encoding.append(fn(entity[name]))
+      except (KeyError, AttributeError):
+ if default is None:
+ raise MissingPropertyError(name)
+ else:
+ encoding.append(default)
+ return encoding
+
+ def __EncodeEntity(self, entity):
+ """Convert the given entity into CSV string.
+
+ Args:
+ entity: The entity to encode.
+
+ Returns:
+ A CSV string.
+ """
+ output = StringIO.StringIO()
+ writer = csv.writer(output, lineterminator='')
+ writer.writerow(self.__ExtractProperties(entity))
+ return output.getvalue()
+
+ def __SerializeEntity(self, entity):
+ """Creates a string representation of an entity.
+
+ Args:
+ entity: The entity to serialize.
+
+ Returns:
+ A serialized representation of an entity.
+ """
+ encoding = self.__EncodeEntity(entity)
+ if not isinstance(encoding, unicode):
+ encoding = unicode(encoding, 'utf-8')
+ encoding = encoding.encode('utf-8')
+ return encoding
+
+ def output_entities(self, entity_generator):
+ """Outputs the downloaded entities.
+
+ This implementation writes CSV.
+
+ Args:
+ entity_generator: A generator that yields the downloaded entities
+ in key order.
+ """
+ CheckOutputFile(self.output_filename)
+ output_file = open(self.output_filename, 'w')
+ logger.debug('Export complete, writing to file')
+ output_file.writelines(self.__SerializeEntity(entity) + '\n'
+ for entity in entity_generator)
+
+ def initialize(self, filename, exporter_opts):
+ """Performs initialization and validation of the output file.
+
+ This implementation checks that the input file exists and can be
+ opened for writing.
+
+ Args:
+ filename: The string given as the --filename flag argument.
+ exporter_opts: The string given as the --exporter_opts flag argument.
+ """
+ CheckOutputFile(filename)
+ self.output_filename = filename
+
+ def finalize(self):
+ """Performs finalization actions after the download completes."""
+ pass
+
+ @staticmethod
+ def RegisteredExporters():
+ """Returns a dictionary of the exporter instances that have been created."""
+ return dict(Exporter.__exporters)
+
+ @staticmethod
+ def RegisteredExporter(kind):
+ """Returns an exporter instance for the given kind if it exists."""
+ return Exporter.__exporters[kind]
+
+
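+# A matching config-file sketch for the Exporter API above (illustrative
+# only: the 'Album' kind and its property names are assumptions). Each tuple
+# is (name, converter, default); a None default makes the property required.
+#
+#   from google.appengine.tools.bulkloader import Exporter
+#
+#   class AlbumExporter(Exporter):
+#     def __init__(self):
+#       Exporter.__init__(self, 'Album',
+#                         [('title', str, None),
+#                          ('release_year', str, '')])
+#
+#   exporters = [AlbumExporter]
+
+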
+class DumpExporter(Exporter):
+ """An exporter which dumps protobuffers to a file."""
+
+ def __init__(self, kind, result_db_filename):
+ self.kind = kind
+ self.result_db_filename = result_db_filename
+
+ def output_entities(self, entity_generator):
+ shutil.copyfile(self.result_db_filename, self.output_filename)
+
+
+class MapperRetry(Error):
+ """An exception that indicates a non-fatal error during mapping."""
+
+
+class Mapper(object):
+ """A base class for serializing datastore entities.
+
+ To add a handler for exporting an entity kind from your datastore,
+ write a subclass of this class that calls Mapper.__init__ from your
+ class's __init__.
+
+ You need to implement to batch_apply or apply method on your subclass
+ for the map to do anything.
+ """
+
+ __mappers = {}
+ kind = None
+
+ def __init__(self, kind):
+ """Constructor.
+
+    Populates this Mapper's kind.
+
+ Args:
+ kind: a string containing the entity kind that this mapper handles
+ """
+ Validate(kind, basestring)
+ self.kind = kind
+
+ GetImplementationClass(kind)
+
+ @staticmethod
+ def RegisterMapper(mapper):
+ """Register mapper and the Mapper instance for its kind.
+
+ Args:
+ mapper: A Mapper instance.
+ """
+ Mapper.__mappers[mapper.kind] = mapper
+
+ def initialize(self, mapper_opts):
+ """Performs initialization.
+
+ Args:
+ mapper_opts: The string given as the --mapper_opts flag argument.
+ """
+ pass
+
+ def finalize(self):
+ """Performs finalization actions after the download completes."""
+ pass
+
+  def apply(self, entity):
+    """Applies the map to a single entity; the default does nothing."""
+    print 'Default map function doing nothing to %s' % entity
+
+  def batch_apply(self, entities):
+    """Applies the map to a batch of entities by calling apply on each."""
+    for entity in entities:
+      self.apply(entity)
+
+ @staticmethod
+ def RegisteredMappers():
+ """Returns a dictionary of the mapper instances that have been created."""
+ return dict(Mapper.__mappers)
+
+ @staticmethod
+ def RegisteredMapper(kind):
+ """Returns an mapper instance for the given kind if it exists."""
+ return Mapper.__mappers[kind]
+
+
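+# A config-file sketch for the Mapper API above (illustrative only: the
+# 'Album' kind is an assumption). Override apply (per entity) or batch_apply
+# (per batch) to do the actual work.
+#
+#   from google.appengine.tools.bulkloader import Mapper
+#
+#   class AlbumMapper(Mapper):
+#     def __init__(self):
+#       Mapper.__init__(self, 'Album')
+#
+#     def apply(self, entity):
+#       print 'Visiting %s' % entity.key()
+#
+#   mappers = [AlbumMapper]
+
+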
+class QueueJoinThread(threading.Thread):
+ """A thread that joins a queue and exits.
+
+ Queue joins do not have a timeout. To simulate a queue join with
+ timeout, run this thread and join it with a timeout.
+ """
+
+ def __init__(self, queue):
+ """Initialize a QueueJoinThread.
+
+ Args:
+ queue: The queue for this thread to join.
+ """
+ threading.Thread.__init__(self)
+ assert isinstance(queue, (Queue.Queue, ReQueue))
+ self.queue = queue
+
+ def run(self):
+ """Perform the queue join in this thread."""
+ self.queue.join()
+
+
+def InterruptibleQueueJoin(queue,
+ thread_local,
+ thread_pool,
+ queue_join_thread_factory=QueueJoinThread,
+ check_workers=True):
+ """Repeatedly joins the given ReQueue or Queue.Queue with short timeout.
+
+ Between each timeout on the join, worker threads are checked.
+
+ Args:
+ queue: A Queue.Queue or ReQueue instance.
+ thread_local: A threading.local instance which indicates interrupts.
+ thread_pool: An AdaptiveThreadPool instance.
+ queue_join_thread_factory: Used for dependency injection.
+ check_workers: Whether to interrupt the join on worker death.
+
+ Returns:
+ True unless the queue join is interrupted by SIGINT or worker death.
+ """
+ thread = queue_join_thread_factory(queue)
+ thread.start()
+ while True:
+ thread.join(timeout=.5)
+ if not thread.isAlive():
+ return True
+ if thread_local.shut_down:
+ logger.debug('Queue join interrupted')
+ return False
+ if check_workers:
+ for worker_thread in thread_pool.Threads():
+ if not worker_thread.isAlive():
+ return False
+
+
+def ShutdownThreads(data_source_thread, thread_pool):
+ """Shuts down the worker and data source threads.
+
+ Args:
+ data_source_thread: A running DataSourceThread instance.
+ thread_pool: An AdaptiveThreadPool instance with workers registered.
+ """
+ logger.info('An error occurred. Shutting down...')
+
+ data_source_thread.exit_flag = True
+
+ thread_pool.Shutdown()
+
+ data_source_thread.join(timeout=3.0)
+ if data_source_thread.isAlive():
+ logger.warn('%s hung while trying to exit',
+ data_source_thread.GetFriendlyName())
+
+
+class BulkTransporterApp(object):
+ """Class to wrap bulk transport application functionality."""
+
+ def __init__(self,
+ arg_dict,
+ input_generator_factory,
+ throttle,
+ progress_db,
+ progresstrackerthread_factory,
+ max_queue_size=DEFAULT_QUEUE_SIZE,
+ request_manager_factory=RequestManager,
+ datasourcethread_factory=DataSourceThread,
+ progress_queue_factory=Queue.Queue,
+ thread_pool_factory=adaptive_thread_pool.AdaptiveThreadPool):
+ """Instantiate a BulkTransporterApp.
+
+    Uploads or downloads data to or from the application using HTTP requests.
+ When run, the class will spin up a number of threads to read entities
+ from the data source, pass those to a number of worker threads
+ for sending to the application, and track all of the progress in a
+ small database in case an error or pause/termination requires a
+ restart/resumption of the upload process.
+
+ Args:
+ arg_dict: Dictionary of command line options.
+ input_generator_factory: A factory that creates a WorkItem generator.
+ throttle: A Throttle instance.
+ progress_db: The database to use for replaying/recording progress.
+ progresstrackerthread_factory: Used for dependency injection.
+ max_queue_size: Maximum size of the queues before they should block.
+ request_manager_factory: Used for dependency injection.
+ datasourcethread_factory: Used for dependency injection.
+ progress_queue_factory: Used for dependency injection.
+ thread_pool_factory: Used for dependency injection.
+ """
+ self.app_id = arg_dict['app_id']
+ self.post_url = arg_dict['url']
+ self.kind = arg_dict['kind']
+ self.batch_size = arg_dict['batch_size']
+ self.input_generator_factory = input_generator_factory
+ self.num_threads = arg_dict['num_threads']
+ self.email = arg_dict['email']
+ self.passin = arg_dict['passin']
+ self.dry_run = arg_dict['dry_run']
+ self.throttle = throttle
+ self.progress_db = progress_db
+ self.progresstrackerthread_factory = progresstrackerthread_factory
+ self.max_queue_size = max_queue_size
+ self.request_manager_factory = request_manager_factory
+ self.datasourcethread_factory = datasourcethread_factory
+ self.progress_queue_factory = progress_queue_factory
+ self.thread_pool_factory = thread_pool_factory
+ (scheme,
+ self.host_port, self.url_path,
+ unused_query, unused_fragment) = urlparse.urlsplit(self.post_url)
+ self.secure = (scheme == 'https')
+
+ def Run(self):
+ """Perform the work of the BulkTransporterApp.
+
+ Raises:
+ AuthenticationError: If authentication is required and fails.
+
+ Returns:
+ Error code suitable for sys.exit, e.g. 0 on success, 1 on failure.
+ """
+ self.error = False
+ thread_pool = self.thread_pool_factory(
+ self.num_threads, queue_size=self.max_queue_size)
+
+ self.throttle.Register(threading.currentThread())
+ threading.currentThread().exit_flag = False
+
+ progress_queue = self.progress_queue_factory(self.max_queue_size)
+ request_manager = self.request_manager_factory(self.app_id,
+ self.host_port,
+ self.url_path,
+ self.kind,
+ self.throttle,
+ self.batch_size,
+ self.secure,
+ self.email,
+ self.passin,
+ self.dry_run)
+ try:
+ request_manager.Authenticate()
+ except Exception, e:
+ self.error = True
+ if not isinstance(e, urllib2.HTTPError) or (
+ e.code != 302 and e.code != 401):
+ logger.exception('Exception during authentication')
+ raise AuthenticationError()
+ if (request_manager.auth_called and
+ not request_manager.authenticated):
+ self.error = True
+ raise AuthenticationError('Authentication failed')
+
+ for thread in thread_pool.Threads():
+ self.throttle.Register(thread)
+
+ self.progress_thread = self.progresstrackerthread_factory(
+ progress_queue, self.progress_db)
+
+ if self.progress_db.UseProgressData():
+ logger.debug('Restarting upload using progress database')
+ progress_generator_factory = self.progress_db.GetProgressStatusGenerator
+ else:
+ progress_generator_factory = None
+
+ self.data_source_thread = (
+ self.datasourcethread_factory(request_manager,
+ thread_pool,
+ progress_queue,
+ self.input_generator_factory,
+ progress_generator_factory))
+
+ thread_local = threading.local()
+ thread_local.shut_down = False
+
+ def Interrupt(unused_signum, unused_frame):
+ """Shutdown gracefully in response to a signal."""
+ thread_local.shut_down = True
+ self.error = True
+
+ signal.signal(signal.SIGINT, Interrupt)
+
+ self.progress_thread.start()
+ self.data_source_thread.start()
+
+
+ while not thread_local.shut_down:
+ self.data_source_thread.join(timeout=0.25)
+
+ if self.data_source_thread.isAlive():
+ for thread in list(thread_pool.Threads()) + [self.progress_thread]:
+ if not thread.isAlive():
+ logger.info('Unexpected thread death: %s', thread.getName())
+ thread_local.shut_down = True
+ self.error = True
+ break
+ else:
+ break
+
+ def _Join(ob, msg):
+ logger.debug('Waiting for %s...', msg)
+ if isinstance(ob, threading.Thread):
+ ob.join(timeout=3.0)
+ if ob.isAlive():
+ logger.debug('Joining %s failed', ob)
+ else:
+ logger.debug('... done.')
+ elif isinstance(ob, (Queue.Queue, ReQueue)):
+ if not InterruptibleQueueJoin(ob, thread_local, thread_pool):
+ ShutdownThreads(self.data_source_thread, thread_pool)
+ else:
+ ob.join()
+ logger.debug('... done.')
+
+ if self.data_source_thread.error or thread_local.shut_down:
+ ShutdownThreads(self.data_source_thread, thread_pool)
+ else:
+ _Join(thread_pool.requeue, 'worker threads to finish')
+
+ thread_pool.Shutdown()
+ thread_pool.JoinThreads()
+ thread_pool.CheckErrors()
+ print ''
+
+ if self.progress_thread.isAlive():
+ InterruptibleQueueJoin(progress_queue, thread_local, thread_pool,
+ check_workers=False)
+ else:
+ logger.warn('Progress thread exited prematurely')
+
+ progress_queue.put(_THREAD_SHOULD_EXIT)
+ _Join(self.progress_thread, 'progress_thread to terminate')
+ self.progress_thread.CheckError()
+ if not thread_local.shut_down:
+ self.progress_thread.WorkFinished()
+
+ self.data_source_thread.CheckError()
+
+ return self.ReportStatus()
+
+ def ReportStatus(self):
+ """Display a message reporting the final status of the transfer."""
+ raise NotImplementedError()
+
+
+class BulkUploaderApp(BulkTransporterApp):
+ """Class to encapsulate bulk uploader functionality."""
+
+ def __init__(self, *args, **kwargs):
+ BulkTransporterApp.__init__(self, *args, **kwargs)
+
+ def ReportStatus(self):
+ """Display a message reporting the final status of the transfer."""
+ total_up, duration = self.throttle.TotalTransferred(
+ remote_api_throttle.BANDWIDTH_UP)
+ s_total_up, unused_duration = self.throttle.TotalTransferred(
+ remote_api_throttle.HTTPS_BANDWIDTH_UP)
+ total_up += s_total_up
+ total = total_up
+    logger.info('%d entities total, %d previously transferred',
+ self.data_source_thread.read_count,
+ self.data_source_thread.xfer_count)
+ transfer_count = self.progress_thread.EntitiesTransferred()
+ logger.info('%d entities (%d bytes) transferred in %.1f seconds',
+ transfer_count, total, duration)
+ if (self.data_source_thread.read_all and
+ transfer_count +
+ self.data_source_thread.xfer_count >=
+ self.data_source_thread.read_count):
+ logger.info('All entities successfully transferred')
+ return 0
+ else:
+ logger.info('Some entities not successfully transferred')
+ return 1
+
+
+class BulkDownloaderApp(BulkTransporterApp):
+ """Class to encapsulate bulk downloader functionality."""
+
+ def __init__(self, *args, **kwargs):
+ BulkTransporterApp.__init__(self, *args, **kwargs)
+
+ def ReportStatus(self):
+ """Display a message reporting the final status of the transfer."""
+ total_down, duration = self.throttle.TotalTransferred(
+ remote_api_throttle.BANDWIDTH_DOWN)
+ s_total_down, unused_duration = self.throttle.TotalTransferred(
+ remote_api_throttle.HTTPS_BANDWIDTH_DOWN)
+ total_down += s_total_down
+ total = total_down
+ existing_count = self.progress_thread.existing_count
+ xfer_count = self.progress_thread.EntitiesTransferred()
+ logger.info('Have %d entities, %d previously transferred',
+ xfer_count, existing_count)
+ logger.info('%d entities (%d bytes) transferred in %.1f seconds',
+ xfer_count, total, duration)
+ if self.error:
+ return 1
+ else:
+ return 0
+
+
+class BulkMapperApp(BulkTransporterApp):
+ """Class to encapsulate bulk map functionality."""
+
+ def __init__(self, *args, **kwargs):
+ BulkTransporterApp.__init__(self, *args, **kwargs)
+
+ def ReportStatus(self):
+ """Display a message reporting the final status of the transfer."""
+ total_down, duration = self.throttle.TotalTransferred(
+ remote_api_throttle.BANDWIDTH_DOWN)
+ s_total_down, unused_duration = self.throttle.TotalTransferred(
+ remote_api_throttle.HTTPS_BANDWIDTH_DOWN)
+ total_down += s_total_down
+ total = total_down
+ xfer_count = self.progress_thread.EntitiesTransferred()
+ logger.info('The following may be inaccurate if any mapper tasks '
+ 'encountered errors and had to be retried.')
+ logger.info('Applied mapper to %s entities.',
+ xfer_count)
+ logger.info('%s entities (%s bytes) transferred in %.1f seconds',
+ xfer_count, total, duration)
+ if self.error:
+ return 1
+ else:
+ return 0
+
+
+def PrintUsageExit(code):
+ """Prints usage information and exits with a status code.
+
+ Args:
+ code: Status code to pass to sys.exit() after displaying usage information.
+ """
+ print __doc__ % {'arg0': sys.argv[0]}
+ sys.stdout.flush()
+ sys.stderr.flush()
+ sys.exit(code)
+
+
+REQUIRED_OPTION = object()
+
+
+FLAG_SPEC = ['debug',
+ 'help',
+ 'url=',
+ 'filename=',
+ 'batch_size=',
+ 'kind=',
+ 'num_threads=',
+ 'bandwidth_limit=',
+ 'rps_limit=',
+ 'http_limit=',
+ 'db_filename=',
+ 'app_id=',
+ 'config_file=',
+ 'has_header',
+ 'csv_has_header',
+ 'auth_domain=',
+ 'result_db_filename=',
+ 'download',
+ 'loader_opts=',
+ 'exporter_opts=',
+ 'log_file=',
+ 'mapper_opts=',
+ 'email=',
+ 'passin',
+ 'map',
+ 'dry_run',
+ 'dump',
+ 'restore',
+ ]
+
+
+def ParseArguments(argv, die_fn=lambda: PrintUsageExit(1)):
+ """Parses command-line arguments.
+
+ Prints out a help message if -h or --help is supplied.
+
+ Args:
+ argv: List of command-line arguments.
+ die_fn: Function to invoke to end the program.
+
+ Returns:
+ A dictionary containing the value of command-line options.
+ """
+ opts, unused_args = getopt.getopt(
+ argv[1:],
+ 'h',
+ FLAG_SPEC)
+
+ arg_dict = {}
+
+ arg_dict['url'] = REQUIRED_OPTION
+ arg_dict['filename'] = None
+ arg_dict['config_file'] = None
+ arg_dict['kind'] = None
+
+ arg_dict['batch_size'] = None
+ arg_dict['num_threads'] = DEFAULT_THREAD_COUNT
+ arg_dict['bandwidth_limit'] = DEFAULT_BANDWIDTH_LIMIT
+ arg_dict['rps_limit'] = DEFAULT_RPS_LIMIT
+ arg_dict['http_limit'] = DEFAULT_REQUEST_LIMIT
+
+ arg_dict['db_filename'] = None
+ arg_dict['app_id'] = ''
+ arg_dict['auth_domain'] = 'gmail.com'
+ arg_dict['has_header'] = False
+ arg_dict['result_db_filename'] = None
+ arg_dict['download'] = False
+ arg_dict['loader_opts'] = None
+ arg_dict['exporter_opts'] = None
+ arg_dict['debug'] = False
+ arg_dict['log_file'] = None
+ arg_dict['email'] = None
+ arg_dict['passin'] = False
+ arg_dict['mapper_opts'] = None
+ arg_dict['map'] = False
+ arg_dict['dry_run'] = False
+ arg_dict['dump'] = False
+ arg_dict['restore'] = False
+
+ def ExpandFilename(filename):
+ """Expand shell variables and ~usernames in filename."""
+ return os.path.expandvars(os.path.expanduser(filename))
+
+ for option, value in opts:
+ if option == '--debug':
+ arg_dict['debug'] = True
+ elif option in ('-h', '--help'):
+ PrintUsageExit(0)
+ elif option == '--url':
+ arg_dict['url'] = value
+ elif option == '--filename':
+ arg_dict['filename'] = ExpandFilename(value)
+ elif option == '--batch_size':
+ arg_dict['batch_size'] = int(value)
+ elif option == '--kind':
+ arg_dict['kind'] = value
+ elif option == '--num_threads':
+ arg_dict['num_threads'] = int(value)
+ elif option == '--bandwidth_limit':
+ arg_dict['bandwidth_limit'] = int(value)
+ elif option == '--rps_limit':
+ arg_dict['rps_limit'] = int(value)
+ elif option == '--http_limit':
+ arg_dict['http_limit'] = int(value)
+ elif option == '--db_filename':
+ arg_dict['db_filename'] = ExpandFilename(value)
+ elif option == '--app_id':
+ arg_dict['app_id'] = value
+ elif option == '--config_file':
+ arg_dict['config_file'] = ExpandFilename(value)
+ elif option == '--auth_domain':
+ arg_dict['auth_domain'] = value
+ elif option == '--has_header':
+ arg_dict['has_header'] = True
+ elif option == '--csv_has_header':
+ print >>sys.stderr, ('--csv_has_header is deprecated, please use '
+ '--has_header.')
+ arg_dict['has_header'] = True
+ elif option == '--result_db_filename':
+ arg_dict['result_db_filename'] = ExpandFilename(value)
+ elif option == '--download':
+ arg_dict['download'] = True
+ elif option == '--loader_opts':
+ arg_dict['loader_opts'] = value
+ elif option == '--exporter_opts':
+ arg_dict['exporter_opts'] = value
+ elif option == '--log_file':
+ arg_dict['log_file'] = ExpandFilename(value)
+ elif option == '--email':
+ arg_dict['email'] = value
+ elif option == '--passin':
+ arg_dict['passin'] = True
+ elif option == '--map':
+ arg_dict['map'] = True
+ elif option == '--mapper_opts':
+ arg_dict['mapper_opts'] = value
+ elif option == '--dry_run':
+ arg_dict['dry_run'] = True
+ elif option == '--dump':
+ arg_dict['dump'] = True
+ elif option == '--restore':
+ arg_dict['restore'] = True
+
+ return ProcessArguments(arg_dict, die_fn=die_fn)
+
+
+def ThrottleLayout(bandwidth_limit, http_limit, rps_limit):
+ """Return a dictionary indicating the throttle options."""
+ bulkloader_limits = dict(remote_api_throttle.NO_LIMITS)
+ bulkloader_limits.update({
+ remote_api_throttle.BANDWIDTH_UP: bandwidth_limit,
+ remote_api_throttle.BANDWIDTH_DOWN: bandwidth_limit,
+ remote_api_throttle.REQUESTS: http_limit,
+ remote_api_throttle.HTTPS_BANDWIDTH_UP: bandwidth_limit,
+ remote_api_throttle.HTTPS_BANDWIDTH_DOWN: bandwidth_limit,
+ remote_api_throttle.HTTPS_REQUESTS: http_limit,
+ remote_api_throttle.ENTITIES_FETCHED: rps_limit,
+ remote_api_throttle.ENTITIES_MODIFIED: rps_limit,
+ })
+ return bulkloader_limits
+
+
+def CheckOutputFile(filename):
+ """Check that the given file does not exist and can be opened for writing.
+
+ Args:
+ filename: The name of the file.
+
+ Raises:
+    FileExistsError: if the given filename already exists.
+    FileNotWritableError: if the given filename's directory is not writable.
+ """
+ full_path = os.path.abspath(filename)
+ if os.path.exists(full_path):
+ raise FileExistsError('%s: output file exists' % filename)
+ elif not os.access(os.path.dirname(full_path), os.W_OK):
+ raise FileNotWritableError(
+ '%s: not writable' % os.path.dirname(full_path))
+
+
+def LoadConfig(config_file_name, exit_fn=sys.exit):
+ """Loads a config file and registers any Loader classes present.
+
+ Args:
+ config_file_name: The name of the configuration file.
+ exit_fn: Used for dependency injection.
+ """
+ if config_file_name:
+ config_file = open(config_file_name, 'r')
+ try:
+ bulkloader_config = imp.load_module(
+ 'bulkloader_config', config_file, config_file_name,
+ ('', 'r', imp.PY_SOURCE))
+ sys.modules['bulkloader_config'] = bulkloader_config
+
+ if hasattr(bulkloader_config, 'loaders'):
+ for cls in bulkloader_config.loaders:
+ Loader.RegisterLoader(cls())
+
+ if hasattr(bulkloader_config, 'exporters'):
+ for cls in bulkloader_config.exporters:
+ Exporter.RegisterExporter(cls())
+
+ if hasattr(bulkloader_config, 'mappers'):
+ for cls in bulkloader_config.mappers:
+ Mapper.RegisterMapper(cls())
+
+ except NameError, e:
+ m = re.search(r"[^']*'([^']*)'.*", str(e))
+ if m.groups() and m.group(1) == 'Loader':
+ print >>sys.stderr, """
+The config file format has changed and you appear to be using an old-style
+config file. Please make the following changes:
+
+1. At the top of the file, add this:
+
+from google.appengine.tools.bulkloader import Loader
+
+2. For each of your Loader subclasses add the following at the end of the
+   __init__ definition:
+
+self.alias_old_names()
+
+3. At the bottom of the file, add this:
+
+loaders = [MyLoader1,...,MyLoaderN]
+
+Where MyLoader1,...,MyLoaderN are the Loader subclasses you want the bulkloader
+to have access to.
+"""
+ exit_fn(1)
+ else:
+ raise
+ except Exception, e:
+ if isinstance(e, NameClashError) or 'bulkloader_config' in vars() and (
+ hasattr(bulkloader_config, 'bulkloader') and
+ isinstance(e, bulkloader_config.bulkloader.NameClashError)):
+ print >> sys.stderr, (
+          'Found both %s and %s while aliasing old names on %s.' %
+ (e.old_name, e.new_name, e.klass))
+ exit_fn(1)
+ else:
+ raise
+
+
+def GetArgument(kwargs, name, die_fn):
+ """Get the value of the key name in kwargs, or die with die_fn.
+
+ Args:
+ kwargs: A dictionary containing the options for the bulkloader.
+ name: The name of a bulkloader option.
+ die_fn: The function to call to exit the program.
+
+ Returns:
+    The value of kwargs[name] if name is in kwargs.
+ """
+ if name in kwargs:
+ return kwargs[name]
+ else:
+ print >>sys.stderr, '%s argument required' % name
+ die_fn()
+
+
+def _MakeSignature(app_id=None,
+ url=None,
+ kind=None,
+ db_filename=None,
+ perform_map=None,
+ download=None,
+ has_header=None,
+ result_db_filename=None,
+ dump=None,
+ restore=None):
+ """Returns a string that identifies the important options for the database."""
+ if download:
+ result_db_line = 'result_db: %s' % result_db_filename
+ else:
+ result_db_line = ''
+ return u"""
+ app_id: %s
+ url: %s
+ kind: %s
+ download: %s
+ map: %s
+ dump: %s
+ restore: %s
+ progress_db: %s
+ has_header: %s
+ %s
+ """ % (app_id, url, kind, download, perform_map, dump, restore, db_filename,
+ has_header, result_db_line)
+
+
+def ProcessArguments(arg_dict,
+ die_fn=lambda: sys.exit(1)):
+ """Processes non command-line input arguments.
+
+ Args:
+ arg_dict: Dictionary containing the values of bulkloader options.
+ die_fn: Function to call in case of an error during argument processing.
+
+ Returns:
+ A dictionary of bulkloader options.
+ """
+ app_id = GetArgument(arg_dict, 'app_id', die_fn)
+ url = GetArgument(arg_dict, 'url', die_fn)
+ dump = GetArgument(arg_dict, 'dump', die_fn)
+ restore = GetArgument(arg_dict, 'restore', die_fn)
+ filename = GetArgument(arg_dict, 'filename', die_fn)
+ batch_size = GetArgument(arg_dict, 'batch_size', die_fn)
+ kind = GetArgument(arg_dict, 'kind', die_fn)
+ db_filename = GetArgument(arg_dict, 'db_filename', die_fn)
+ config_file = GetArgument(arg_dict, 'config_file', die_fn)
+ result_db_filename = GetArgument(arg_dict, 'result_db_filename', die_fn)
+ download = GetArgument(arg_dict, 'download', die_fn)
+ log_file = GetArgument(arg_dict, 'log_file', die_fn)
+ perform_map = GetArgument(arg_dict, 'map', die_fn)
+
+ errors = []
+
+ if batch_size is None:
+ if download or perform_map:
+ arg_dict['batch_size'] = DEFAULT_DOWNLOAD_BATCH_SIZE
+ else:
+ arg_dict['batch_size'] = DEFAULT_BATCH_SIZE
+ elif batch_size <= 0:
+ errors.append('batch_size must be at least 1')
+
+ if db_filename is None:
+ arg_dict['db_filename'] = time.strftime(
+ 'bulkloader-progress-%Y%m%d.%H%M%S.sql3')
+
+ if result_db_filename is None:
+ arg_dict['result_db_filename'] = time.strftime(
+ 'bulkloader-results-%Y%m%d.%H%M%S.sql3')
+
+ if log_file is None:
+ arg_dict['log_file'] = time.strftime('bulkloader-log-%Y%m%d.%H%M%S')
+
+ required = '%s argument required'
+
+ if config_file is None and not dump and not restore:
+ errors.append('One of --config_file, --dump, or --restore is required')
+
+ if url is REQUIRED_OPTION:
+ errors.append(required % 'url')
+
+ if not filename and not perform_map:
+ errors.append(required % 'filename')
+
+ if kind is None:
+    if download or perform_map:
+ errors.append('kind argument required for this operation')
+ elif not dump and not restore:
+ errors.append(
+ 'kind argument required unless --dump or --restore is specified')
+
+ if not app_id:
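+    # Infer the app_id from the URL host when possible, e.g.
+    # 'http://myapp.appspot.com/remote_api' yields app_id 'myapp'.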
+ if url and url is not REQUIRED_OPTION:
+ (unused_scheme, host_port, unused_url_path,
+ unused_query, unused_fragment) = urlparse.urlsplit(url)
+ suffix_idx = host_port.find('.appspot.com')
+ if suffix_idx > -1:
+ arg_dict['app_id'] = host_port[:suffix_idx]
+ elif host_port.split(':')[0].endswith('google.com'):
+ arg_dict['app_id'] = host_port.split('.')[0]
+ else:
+ errors.append('app_id argument required for non appspot.com domains')
+
+ if errors:
+ print >>sys.stderr, '\n'.join(errors)
+ die_fn()
+
+ return arg_dict
+
+
+def ParseKind(kind):
+  """Parse the kind flag, turning '(kind1,kind2)' into a tuple of kinds."""
+  if kind and kind[0] == '(' and kind[-1] == ')':
+    return tuple(kind[1:-1].split(','))
+  else:
+    return kind
+
+
+def _PerformBulkload(arg_dict,
+ check_file=CheckFile,
+ check_output_file=CheckOutputFile):
+ """Runs the bulkloader, given the command line options.
+
+ Args:
+ arg_dict: Dictionary of bulkloader options.
+ check_file: Used for dependency injection.
+ check_output_file: Used for dependency injection.
+
+ Returns:
+ An exit code.
+
+ Raises:
+ ConfigurationError: if inconsistent options are passed.
+ """
+ app_id = arg_dict['app_id']
+ url = arg_dict['url']
+ filename = arg_dict['filename']
+ batch_size = arg_dict['batch_size']
+ kind = arg_dict['kind']
+ num_threads = arg_dict['num_threads']
+ bandwidth_limit = arg_dict['bandwidth_limit']
+ rps_limit = arg_dict['rps_limit']
+ http_limit = arg_dict['http_limit']
+ db_filename = arg_dict['db_filename']
+ config_file = arg_dict['config_file']
+ auth_domain = arg_dict['auth_domain']
+ has_header = arg_dict['has_header']
+ download = arg_dict['download']
+ result_db_filename = arg_dict['result_db_filename']
+ loader_opts = arg_dict['loader_opts']
+ exporter_opts = arg_dict['exporter_opts']
+ mapper_opts = arg_dict['mapper_opts']
+ email = arg_dict['email']
+ passin = arg_dict['passin']
+ perform_map = arg_dict['map']
+ dump = arg_dict['dump']
+ restore = arg_dict['restore']
+
+ os.environ['AUTH_DOMAIN'] = auth_domain
+
+ kind = ParseKind(kind)
+
+ if not dump and not restore:
+ check_file(config_file)
+
+ if download and perform_map:
+ logger.error('--download and --map are mutually exclusive.')
+
+ if download or dump:
+ check_output_file(filename)
+ elif not perform_map:
+ check_file(filename)
+
+ if dump:
+ Exporter.RegisterExporter(DumpExporter(kind, result_db_filename))
+ elif restore:
+ Loader.RegisterLoader(RestoreLoader(kind))
+ else:
+ LoadConfig(config_file)
+
+ os.environ['APPLICATION_ID'] = app_id
+
+ throttle_layout = ThrottleLayout(bandwidth_limit, http_limit, rps_limit)
+ logger.info('Throttling transfers:')
+ logger.info('Bandwidth: %s bytes/second', bandwidth_limit)
+ logger.info('HTTP connections: %s/second', http_limit)
+ logger.info('Entities inserted/fetched/modified: %s/second', rps_limit)
+
+ throttle = remote_api_throttle.Throttle(layout=throttle_layout)
+ signature = _MakeSignature(app_id=app_id,
+ url=url,
+ kind=kind,
+ db_filename=db_filename,
+ download=download,
+ perform_map=perform_map,
+ has_header=has_header,
+ result_db_filename=result_db_filename,
+ dump=dump,
+ restore=restore)
+
+
+ max_queue_size = max(DEFAULT_QUEUE_SIZE, 3 * num_threads + 5)
+
+ if db_filename == 'skip':
+ progress_db = StubProgressDatabase()
+ elif not download and not perform_map and not dump:
+ progress_db = ProgressDatabase(db_filename, signature)
+ else:
+ progress_db = ExportProgressDatabase(db_filename, signature)
+
+ return_code = 1
+
+ if not download and not perform_map and not dump:
+ loader = Loader.RegisteredLoader(kind)
+ try:
+ loader.initialize(filename, loader_opts)
+ workitem_generator_factory = GetCSVGeneratorFactory(
+ kind, filename, batch_size, has_header)
+
+ app = BulkUploaderApp(arg_dict,
+ workitem_generator_factory,
+ throttle,
+ progress_db,
+ ProgressTrackerThread,
+ max_queue_size,
+ RequestManager,
+ DataSourceThread,
+ Queue.Queue)
+ try:
+ return_code = app.Run()
+ except AuthenticationError:
+ logger.info('Authentication Failed')
+ finally:
+ loader.finalize()
+ elif not perform_map:
+ result_db = ResultDatabase(result_db_filename, signature)
+ exporter = Exporter.RegisteredExporter(kind)
+ try:
+ exporter.initialize(filename, exporter_opts)
+
+ def KeyRangeGeneratorFactory(request_manager, progress_queue,
+ progress_gen):
+ return KeyRangeItemGenerator(request_manager, kind, progress_queue,
+ progress_gen, DownloadItem)
+
+ def ExportProgressThreadFactory(progress_queue, progress_db):
+ return ExportProgressThread(kind,
+ progress_queue,
+ progress_db,
+ result_db)
+
+ app = BulkDownloaderApp(arg_dict,
+ KeyRangeGeneratorFactory,
+ throttle,
+ progress_db,
+ ExportProgressThreadFactory,
+ 0,
+ RequestManager,
+ DataSourceThread,
+ Queue.Queue)
+ try:
+ return_code = app.Run()
+ except AuthenticationError:
+ logger.info('Authentication Failed')
+ finally:
+ exporter.finalize()
+ elif not download:
+ mapper = Mapper.RegisteredMapper(kind)
+ try:
+ mapper.initialize(mapper_opts)
+ def KeyRangeGeneratorFactory(request_manager, progress_queue,
+ progress_gen):
+ return KeyRangeItemGenerator(request_manager, kind, progress_queue,
+ progress_gen, MapperItem)
+
+ def MapperProgressThreadFactory(progress_queue, progress_db):
+ return MapperProgressThread(kind,
+ progress_queue,
+ progress_db)
+
+ app = BulkMapperApp(arg_dict,
+ KeyRangeGeneratorFactory,
+ throttle,
+ progress_db,
+ MapperProgressThreadFactory,
+ 0,
+ RequestManager,
+ DataSourceThread,
+ Queue.Queue)
+ try:
+ return_code = app.Run()
+ except AuthenticationError:
+ logger.info('Authentication Failed')
+ finally:
+ mapper.finalize()
+ return return_code
+
+
+def SetupLogging(arg_dict):
+ """Sets up logging for the bulkloader.
+
+ Args:
+ arg_dict: Dictionary mapping flag names to their arguments.
+ """
+ format = '[%(levelname)-8s %(asctime)s %(filename)s] %(message)s'
+ debug = arg_dict['debug']
+ log_file = arg_dict['log_file']
+
+ logger.setLevel(logging.DEBUG)
+
+ logger.propagate = False
+
+ file_handler = logging.FileHandler(log_file, 'w')
+ file_handler.setLevel(logging.DEBUG)
+ file_formatter = logging.Formatter(format)
+ file_handler.setFormatter(file_formatter)
+ logger.addHandler(file_handler)
+
+ console = logging.StreamHandler()
+ level = logging.INFO
+ if debug:
+ level = logging.DEBUG
+ console.setLevel(level)
+ console_format = '[%(levelname)-8s] %(message)s'
+ formatter = logging.Formatter(console_format)
+ console.setFormatter(formatter)
+ logger.addHandler(console)
+
+ logger.info('Logging to %s', log_file)
+
+ remote_api_throttle.logger.setLevel(level)
+ remote_api_throttle.logger.addHandler(file_handler)
+ remote_api_throttle.logger.addHandler(console)
+
+ appengine_rpc.logger.setLevel(logging.WARN)
+
+ adaptive_thread_pool.logger.setLevel(logging.DEBUG)
+ adaptive_thread_pool.logger.addHandler(console)
+ adaptive_thread_pool.logger.addHandler(file_handler)
+ adaptive_thread_pool.logger.propagate = False
+
+
+def Run(arg_dict):
+ """Sets up and runs the bulkloader, given the options as keyword arguments.
+
+ Args:
+ arg_dict: Dictionary of bulkloader options
+
+ Returns:
+ An exit code.
+ """
+ arg_dict = ProcessArguments(arg_dict)
+
+ SetupLogging(arg_dict)
+
+ return _PerformBulkload(arg_dict)
+
+
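+# A sketch of driving the bulkloader programmatically instead of from the
+# command line (all concrete values below are illustrative assumptions).
+# Run expects the same keys that ParseArguments would produce; None values
+# are filled in with defaults by ProcessArguments where possible.
+#
+#   from google.appengine.tools import bulkloader
+#
+#   args = dict(app_id='myapp',
+#               url='http://myapp.appspot.com/remote_api',
+#               kind='Album', filename='albums.csv',
+#               config_file='album_config.py',
+#               download=False, map=False, dump=False, restore=False,
+#               has_header=True, debug=False, dry_run=False, passin=False,
+#               email=None, batch_size=None, db_filename=None,
+#               result_db_filename=None, log_file=None,
+#               auth_domain='gmail.com',
+#               num_threads=bulkloader.DEFAULT_THREAD_COUNT,
+#               bandwidth_limit=bulkloader.DEFAULT_BANDWIDTH_LIMIT,
+#               rps_limit=bulkloader.DEFAULT_RPS_LIMIT,
+#               http_limit=bulkloader.DEFAULT_REQUEST_LIMIT,
+#               loader_opts=None, exporter_opts=None, mapper_opts=None)
+#   exit_code = bulkloader.Run(args)
+
+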
+def main(argv):
+ """Runs the importer from the command line."""
+
+ arg_dict = ParseArguments(argv)
+
+ errors = ['%s argument required' % key
+ for (key, value) in arg_dict.iteritems()
+ if value is REQUIRED_OPTION]
+ if errors:
+ print >>sys.stderr, '\n'.join(errors)
+ PrintUsageExit(1)
+
+ SetupLogging(arg_dict)
+ return _PerformBulkload(arg_dict)
+
+
+if __name__ == '__main__':
+ sys.exit(main(sys.argv))