
SPARK SQL - update a MySql table using DataFrames and JDBC

I am trying to insert and update data in MySql using Spark SQL DataFrames and a JDBC connection.

I have succeeded in inserting new data using SaveMode.Append. Is there a way to update data that already exists in the MySql table from Spark SQL?

My code to insert is:

myDataFrame.write.mode(SaveMode.Append).jdbc(JDBCurl,mySqlTable,connectionProperties)

If I change to SaveMode.Overwrite, it deletes the whole table and creates a new one. I am looking for something like the "ON DUPLICATE KEY UPDATE" that is available in MySql.

19
nicola

It is not possible. As of now (Spark 1.6.0 / 2.2.0 SNAPSHOT) the Spark DataFrameWriter supports only four write modes (a short usage sketch follows the list):

  • SaveMode.Overwrite: overwrite the existing data.
  • SaveMode.Append: append the data.
  • SaveMode.Ignore: ignore the operation (i.e. no-op).
  • SaveMode.ErrorIfExists: the default option; throws an exception at runtime.
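
For illustration, a minimal sketch of how a mode is selected on the writer (the URL, table name and credentials below are placeholders, not taken from the question):

import java.util.Properties
import org.apache.spark.sql.SaveMode

val props = new Properties()
props.setProperty("user", "dbUser")       // placeholder credentials
props.setProperty("password", "dbPass")

// Append only inserts new rows; a row whose key already exists raises a
// duplicate-key error instead of being updated.
myDataFrame.write
  .mode(SaveMode.Append)
  .jdbc("jdbc:mysql://host:3306/mydb", "mytable", props)

// Overwrite drops and recreates the table before writing, as noted in the question.
myDataFrame.write
  .mode(SaveMode.Overwrite)
  .jdbc("jdbc:mysql://host:3306/mydb", "mytable", props)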

You can, for example, insert manually using mapPartitions (since an UPSERT operation should be idempotent and as such is easy to implement), write to a temporary table and execute the upsert manually, or use triggers.

In general, achieving upsert behavior for batch operations while keeping decent performance is far from trivial. You have to remember that there will typically be multiple concurrent transactions in place (one per partition), so you have to ensure that there are no write conflicts (usually by using application-specific partitioning). In practice it may be better to write batches to a temporary table and resolve the upsert part directly in the database.
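
As a rough sketch of that temporary-table variant (assuming a target table mytable with columns _id and value and a staging table mytable_staging, all placeholder names not from the question): bulk-write the DataFrame to the staging table, then let MySQL resolve the upsert in a single statement.

import java.sql.DriverManager
import org.apache.spark.sql.SaveMode

// 1) Bulk-load the new data into a staging table (dropped and recreated each run).
myDataFrame.write
  .mode(SaveMode.Overwrite)
  .jdbc(JDBCurl, "mytable_staging", connectionProperties)

// 2) Resolve the upsert inside MySQL with one INSERT ... ON DUPLICATE KEY UPDATE.
val conn = DriverManager.getConnection(JDBCurl, connectionProperties)
try {
  conn.createStatement().executeUpdate(
    """INSERT INTO mytable (_id, value)
      |SELECT _id, value FROM mytable_staging
      |ON DUPLICATE KEY UPDATE value = VALUES(value)""".stripMargin)
} finally {
  conn.close()
}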

21
zero323

It is a pity that Spark has no SaveMode.Upsert mode for such a very common case as upserting.

zero323 is right in general, but I think it should be possible (with a compromise in performance) to offer such a replace feature.

I also want to provide some Java code for this case. It is of course not as performant as Spark's built-in writer, but it should be a good basis for your requirements; modify it to suit your needs:

myDF = myDF.repartition(20); // one connection per partition, see below

myDF.foreachPartition((Iterator<Row> t) -> {
            Connection conn = DriverManager.getConnection(
                    Constants.DB_JDBC_CONN,
                    Constants.DB_JDBC_USER,
                    Constants.DB_JDBC_PASS);

            conn.setAutoCommit(false); // commit explicitly once all batches have executed
            Statement statement = conn.createStatement();

            final int batchSize = 100000;
            int i = 0;
            while (t.hasNext()) {
                Row row = t.next();
                try {
                    // better than REPLACE INTO, less cycles
                    statement.addBatch("INSERT INTO mytable VALUES ("
                            + "'" + row.getAs("_id") + "', "
                            + "'" + row.getStruct(1).get(0) + "'"
                            + ") ON DUPLICATE KEY UPDATE _id='" + row.getAs("_id") + "';");
                    //conn.commit();

                    if (++i % batchSize == 0) {
                        statement.executeBatch();
                    }
                } catch (SQLIntegrityConstraintViolationException e) {
                    //should not occur, nevertheless
                    //conn.commit();
                } catch (SQLException e) {
                    e.printStackTrace();
                } finally {
                    //conn.commit();
                    statement.executeBatch();
                }
            }
            int[] ret = statement.executeBatch();

            System.out.println("Ret val: " + Arrays.toString(ret));
            System.out.println("Update count: " + statement.getUpdateCount());
            conn.commit();

            statement.close();
            conn.close();
});
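
Building the SQL by string concatenation as above is fragile with respect to quoting and SQL injection. Below is a minimal Scala sketch of the same per-partition idea with a parameterized statement; the table and column names (mytable, _id, value), the batch size, and the connection settings (jdbcUrl, dbUser, dbPass) are placeholders:

import java.sql.DriverManager
import org.apache.spark.sql.Row

myDF.repartition(20).foreachPartition { rows: Iterator[Row] =>
  val conn = DriverManager.getConnection(jdbcUrl, dbUser, dbPass) // assumed connection settings
  conn.setAutoCommit(false)
  val stmt = conn.prepareStatement(
    "INSERT INTO mytable (_id, value) VALUES (?, ?) " +
      "ON DUPLICATE KEY UPDATE value = VALUES(value)")
  try {
    var count = 0
    rows.foreach { row =>
      stmt.setString(1, row.getAs[String]("_id"))
      stmt.setString(2, row.getAs[String]("value"))
      stmt.addBatch()
      count += 1
      if (count % 1000 == 0) stmt.executeBatch() // flush in chunks
    }
    stmt.executeBatch() // flush the remainder
    conn.commit()
  } finally {
    stmt.close()
    conn.close()
  }
}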
2
Aydin K.

Overwrite org.apache.spark.sql.execution.datasources.jdbc JdbcUtils.scala, changing the generated "insert into" to "replace into":

import java.sql.{Connection, Driver, DriverManager, PreparedStatement, ResultSet, SQLException}

import scala.collection.JavaConverters._
import scala.util.control.NonFatal
import com.typesafe.scalalogging.Logger
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.jdbc.{DriverRegistry, DriverWrapper, JDBCOptions}
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row}

/**
  * Util functions for JDBC tables.
  */
object UpdateJdbcUtils {

  val logger = Logger(this.getClass)

  /**
    * Returns a factory for creating connections to the given JDBC URL.
    *
    * @param options - JDBC options that contains url, table and other information.
    */
  def createConnectionFactory(options: JDBCOptions): () => Connection = {
    val driverClass: String = options.driverClass
    () => {
      DriverRegistry.register(driverClass)
      val driver: Driver = DriverManager.getDrivers.asScala.collectFirst {
        case d: DriverWrapper if d.wrapped.getClass.getCanonicalName == driverClass => d
        case d if d.getClass.getCanonicalName == driverClass => d
      }.getOrElse {
        throw new IllegalStateException(
          s"Did not find registered driver with class $driverClass")
      }
      driver.connect(options.url, options.asConnectionProperties)
    }
  }

  /**
    * Returns a PreparedStatement that inserts a row into table via conn.
    */
  def insertStatement(conn: Connection, table: String, rddSchema: StructType, dialect: JdbcDialect)
  : PreparedStatement = {
    val columns = rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).mkString(",")
    val placeholders = rddSchema.fields.map(_ => "?").mkString(",")
    val sql = s"REPLACE INTO $table ($columns) VALUES ($placeholders)"
    conn.prepareStatement(sql)
  }

  /**
    * Retrieve standard jdbc types.
    *
    * @param dt The datatype (e.g. [[org.apache.spark.sql.types.StringType]])
    * @return The default JdbcType for this DataType
    */
  def getCommonJDBCType(dt: DataType): Option[JdbcType] = {
    dt match {
      case IntegerType => Option(JdbcType("INTEGER", java.sql.Types.INTEGER))
      case LongType => Option(JdbcType("BIGINT", java.sql.Types.BIGINT))
      case DoubleType => Option(JdbcType("DOUBLE PRECISION", java.sql.Types.DOUBLE))
      case FloatType => Option(JdbcType("REAL", java.sql.Types.FLOAT))
      case ShortType => Option(JdbcType("INTEGER", java.sql.Types.SMALLINT))
      case ByteType => Option(JdbcType("BYTE", java.sql.Types.TINYINT))
      case BooleanType => Option(JdbcType("BIT(1)", java.sql.Types.BIT))
      case StringType => Option(JdbcType("TEXT", java.sql.Types.CLOB))
      case BinaryType => Option(JdbcType("BLOB", java.sql.Types.BLOB))
      case TimestampType => Option(JdbcType("TIMESTAMP", java.sql.Types.TIMESTAMP))
      case DateType => Option(JdbcType("DATE", java.sql.Types.DATE))
      case t: DecimalType => Option(
        JdbcType(s"DECIMAL(${t.precision},${t.scale})", java.sql.Types.DECIMAL))
      case _ => None
    }
  }

  private def getJdbcType(dt: DataType, dialect: JdbcDialect): JdbcType = {
    dialect.getJDBCType(dt).orElse(getCommonJDBCType(dt)).getOrElse(
      throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.simpleString}"))
  }

  // A `JDBCValueGetter` is responsible for getting a value from `ResultSet` into a field
  // for `MutableRow`. The last argument `Int` means the index for the value to be set in
  // the row and also used for the value in `ResultSet`.
  private type JDBCValueGetter = (ResultSet, InternalRow, Int) => Unit

  // A `JDBCValueSetter` is responsible for setting a value from `Row` into a field for
  // `PreparedStatement`. The last argument `Int` means the index for the value to be set
  // in the SQL statement and also used for the value in `Row`.
  private type JDBCValueSetter = (PreparedStatement, Row, Int) => Unit

  /**
    * Saves a partition of a DataFrame to the JDBC database.  This is done in
    * a single database transaction (unless isolation level is "NONE")
    * in order to avoid repeatedly inserting data as much as possible.
    *
    * It is still theoretically possible for rows in a DataFrame to be
    * inserted into the database more than once if a stage somehow fails after
    * the commit occurs but before the stage can return successfully.
    *
    * This is not a closure inside saveTable() because apparently cosmetic
    * implementation changes elsewhere might easily render such a closure
    * non-Serializable.  Instead, we explicitly close over all variables that
    * are used.
    */
  def savePartition(
                     getConnection: () => Connection,
                     table: String,
                     iterator: Iterator[Row],
                     rddSchema: StructType,
                     nullTypes: Array[Int],
                     batchSize: Int,
                     dialect: JdbcDialect,
                     isolationLevel: Int): Iterator[Byte] = {
    val conn = getConnection()
    var committed = false

    var finalIsolationLevel = Connection.TRANSACTION_NONE
    if (isolationLevel != Connection.TRANSACTION_NONE) {
      try {
        val metadata = conn.getMetaData
        if (metadata.supportsTransactions()) {
          // Update to at least use the default isolation, if any transaction level
          // has been chosen and transactions are supported
          val defaultIsolation = metadata.getDefaultTransactionIsolation
          finalIsolationLevel = defaultIsolation
          if (metadata.supportsTransactionIsolationLevel(isolationLevel)) {
            // Finally update to actually requested level if possible
            finalIsolationLevel = isolationLevel
          } else {
            logger.warn(s"Requested isolation level $isolationLevel is not supported; " +
              s"falling back to default isolation level $defaultIsolation")
          }
        } else {
          logger.warn(s"Requested isolation level $isolationLevel, but transactions are unsupported")
        }
      } catch {
        case NonFatal(e) => logger.warn("Exception while detecting transaction support", e)
      }
    }
    val supportsTransactions = finalIsolationLevel != Connection.TRANSACTION_NONE

    try {
      if (supportsTransactions) {
        conn.setAutoCommit(false) // Everything in the same db transaction.
        conn.setTransactionIsolation(finalIsolationLevel)
      }
      val stmt = insertStatement(conn, table, rddSchema, dialect)
      val setters: Array[JDBCValueSetter] = rddSchema.fields.map(_.dataType)
        .map(makeSetter(conn, dialect, _))
      val numFields = rddSchema.fields.length

      try {
        var rowCount = 0
        while (iterator.hasNext) {
          val row = iterator.next()
          var i = 0
          while (i < numFields) {
            if (row.isNullAt(i)) {
              stmt.setNull(i + 1, nullTypes(i))
            } else {
              setters(i).apply(stmt, row, i)
            }
            i = i + 1
          }
          stmt.addBatch()
          rowCount += 1
          if (rowCount % batchSize == 0) {
            stmt.executeBatch()
            rowCount = 0
          }
        }
        if (rowCount > 0) {
          stmt.executeBatch()
        }
      } finally {
        stmt.close()
      }
      if (supportsTransactions) {
        conn.commit()
      }
      committed = true
      Iterator.empty
    } catch {
      case e: SQLException =>
        val cause = e.getNextException
        if (cause != null && e.getCause != cause) {
          if (e.getCause == null) {
            e.initCause(cause)
          } else {
            e.addSuppressed(cause)
          }
        }
        throw e
    } finally {
      if (!committed) {
        // The stage must fail.  We got here through an exception path, so
        // let the exception through unless rollback() or close() want to
        // tell the user about another problem.
        if (supportsTransactions) {
          conn.rollback()
        }
        conn.close()
      } else {
        // The stage must succeed.  We cannot propagate any exception close() might throw.
        try {
          conn.close()
        } catch {
          case e: Exception => logger.warn("Transaction succeeded, but closing failed", e)
        }
      }
    }
  }

  /**
    * Saves the RDD to the database in a single transaction.
    */
  def saveTable(
                 df: DataFrame,
                 url: String,
                 table: String,
                 options: JDBCOptions) {
    val dialect = JdbcDialects.get(url)
    val nullTypes: Array[Int] = df.schema.fields.map { field =>
      getJdbcType(field.dataType, dialect).jdbcNullType
    }

    val rddSchema = df.schema
    val getConnection: () => Connection = createConnectionFactory(options)
    val batchSize = options.batchSize
    val isolationLevel = options.isolationLevel
    df.foreachPartition(iterator => savePartition(
      getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect, isolationLevel)
    )
  }

  private def makeSetter(
                          conn: Connection,
                          dialect: JdbcDialect,
                          dataType: DataType): JDBCValueSetter = dataType match {
    case IntegerType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setInt(pos + 1, row.getInt(pos))

    case LongType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setLong(pos + 1, row.getLong(pos))

    case DoubleType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setDouble(pos + 1, row.getDouble(pos))

    case FloatType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setFloat(pos + 1, row.getFloat(pos))

    case ShortType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setInt(pos + 1, row.getShort(pos))

    case ByteType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setInt(pos + 1, row.getByte(pos))

    case BooleanType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setBoolean(pos + 1, row.getBoolean(pos))

    case StringType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setString(pos + 1, row.getString(pos))

    case BinaryType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setBytes(pos + 1, row.getAs[Array[Byte]](pos))

    case TimestampType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setTimestamp(pos + 1, row.getAs[java.sql.Timestamp](pos))

    case DateType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setDate(pos + 1, row.getAs[java.sql.Date](pos))

    case t: DecimalType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setBigDecimal(pos + 1, row.getDecimal(pos))

    case ArrayType(et, _) =>
      // remove type length parameters from end of type name
      val typeName = getJdbcType(et, dialect).databaseTypeDefinition
        .toLowerCase.split("\\(")(0)
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        val array = conn.createArrayOf(
          typeName,
          row.getSeq[AnyRef](pos).toArray)
        stmt.setArray(pos + 1, array)

    case _ =>
      (_: PreparedStatement, _: Row, pos: Int) =>
        throw new IllegalArgumentException(
          s"Can't translate non-null value for field $pos")
  }
}

Usage:

val url = s"jdbc:mysql://$host/$database?useUnicode=true&characterEncoding=UTF-8"

val parameters: Map[String, String] = Map(
  "url" -> url,
  "dbtable" -> table,
  "driver" -> "com.mysql.jdbc.Driver",
  "numPartitions" -> numPartitions.toString,
  "user" -> user,
  "password" -> password
)
val options = new JDBCOptions(parameters)

for (d <- data) {
  UpdateJdbcUtils.saveTable(d, url, table, options)
}

P.S.: pay attention to deadlocks. Don't use this to update data frequently; use it only for re-runs in case of emergency. I think that is why Spark does not support this officially.

2
user1442346

I could not do that in PySpark, so I decided to use odbc.

url = "jdbc:sqlserver://xxx:1433;databaseName=xxx;user=xxx;password=xxx"
df.write.jdbc(url=url, table="__TableInsert", mode='overwrite')
cnxn  = pyodbc.connect('Driver={ODBC Driver 17 for SQL Server};Server=xxx;Database=xxx;Uid=xxx;Pwd=xxx;', autocommit=False) 
try:
    crsr = cnxn.cursor()
    # DO UPSERTS OR WHATEVER YOU WANT
    crsr.execute("DELETE FROM Table")
    crsr.execute("INSERT INTO Table (Field) SELECT Field FROM __TableInsert")
    cnxn.commit()
except:
    cnxn.rollback()
cnxn.close()
0
Anton Shelin

zero323's answer is right. I just want to add that you can use the JayDeBeApi package ( https://pypi.python.org/pypi/JayDeBeApi/ ) as a workaround to update data in your mysql table. It might also be less trouble since you already have the mysql JDBC driver installed.

The JayDeBeApi module allows you to connect from Python code to databases using Java JDBC. It provides a Python DB-API v2.0 to that database.

We use the Anaconda distribution of Python, and the JayDeBeApi python package comes standard with it.

See the examples at the link above.

0
Tagar