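I came across this page.
1) After fine-tuning is complete, I want to get sentence-level embeddings (the embedding given by the [CLS] token). How can I do that?
2) I also noticed that the code on that page takes a long time to return results on test data. Why is that? Training the model took less time than getting the test predictions. From the code on that page, I did not use the following code block: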
test_InputExamples = test.apply(lambda x: bert.run_classifier.InputExample(guid=None,
    text_a = x[DATA_COLUMN],
    text_b = None,
    label = x[LABEL_COLUMN]), axis = 1)
test_features = bert.run_classifier.convert_examples_to_features(test_InputExamples, label_list, MAX_SEQ_LENGTH, tokenizer)
test_input_fn = run_classifier.input_fn_builder(
    features=test_features,
    seq_length=MAX_SEQ_LENGTH,
    is_training=False,
    drop_remainder=False)
estimator.evaluate(input_fn=test_input_fn, steps=None)
Instead, I used the following function on my entire test set:
def getPrediction(in_sentences):
    labels = ["Negative", "Positive"]
    # guid "" and label 0 are just dummy values
    input_examples = [run_classifier.InputExample(guid="", text_a = x, text_b = None, label = 0) for x in in_sentences]
    input_features = run_classifier.convert_examples_to_features(input_examples, label_list, MAX_SEQ_LENGTH, tokenizer)
    predict_input_fn = run_classifier.input_fn_builder(features=input_features, seq_length=MAX_SEQ_LENGTH, is_training=False, drop_remainder=False)
    predictions = estimator.predict(predict_input_fn)
    return [(sentence, prediction['probabilities'], labels[prediction['labels']]) for sentence, prediction in zip(in_sentences, predictions)]
3) How can I get the prediction probabilities? Is there a way to use the keras predict method?
Update to question 2: can the getPrediction function be used to test 20000 training examples?
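1) From the BERT documentation:
The output dictionary contains:
pooled_output: pooled output of the entire sequence, with shape [batch_size, hidden_size].
sequence_output: representations of every token in the input sequence, with shape [batch_size, max_sequence_length, hidden_size].
I added the pooled_output vector, which corresponds to the CLS vector.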
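To make the relationship between the two outputs concrete, here is a minimal pure-Python sketch of how BERT's pooler derives pooled_output from sequence_output: it takes the hidden state of the first ([CLS]) token and passes it through a dense layer with a tanh activation. The function name pool_first_token and the toy weights are hypothetical and chosen only for illustration; the real module uses learned weights and hidden_size 768.

```python
import math

def pool_first_token(sequence_output, weights, bias):
    """Sketch of BERT-style pooling (hypothetical helper, toy sizes).

    sequence_output: nested list shaped [batch][seq_len][hidden]
    weights:         nested list shaped [hidden][hidden]
    bias:            list shaped [hidden]
    Returns a nested list shaped [batch][hidden].
    """
    pooled = []
    for example in sequence_output:
        cls_vec = example[0]  # hidden state of the first ([CLS]) token
        row = []
        for j in range(len(bias)):
            z = sum(cls_vec[i] * weights[i][j] for i in range(len(cls_vec))) + bias[j]
            row.append(math.tanh(z))  # tanh keeps every value in [-1, 1]
        pooled.append(row)
    return pooled

# Toy input: batch of 2, sequence length 3, hidden size 2 (assumed values).
toy_sequence_output = [
    [[0.5, -1.0], [9.0, 9.0], [9.0, 9.0]],
    [[0.0, 2.0], [9.0, 9.0], [9.0, 9.0]],
]
identity = [[1.0, 0.0], [0.0, 1.0]]
toy_pooled = pool_first_token(toy_sequence_output, identity, [0.0, 0.0])
```

Only the first token of each sequence contributes, which is why the result has shape [batch_size, hidden_size]. If the raw [CLS] hidden state is wanted instead, it would be the first position of sequence_output.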
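3) What you receive are log probabilities. To get ordinary probabilities, just apply softmax.
After that, it is only a matter of having the model report them. I left the log probs in, but they are no longer needed.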
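As a quick check outside TensorFlow, the conversion can be sketched in plain Python. The helper names below are hypothetical; the log-probability values are the ones from the sample output in this answer. Because log_softmax outputs are already normalized, exponentiating them gives the same result as applying softmax.

```python
import math

def softmax(values):
    """Numerically stable softmax over a list of floats."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

def log_probs_to_probs(log_probs):
    """exp() undoes the log in log_softmax, recovering probabilities."""
    return [math.exp(lp) for lp in log_probs]

# Log probabilities for 'That movie was absolutely awful' (from the sample output).
log_probs = [-4.0148855e-03, -5.5197663e+00]

probs_via_exp = log_probs_to_probs(log_probs)
probs_via_softmax = softmax(log_probs)
```

Both results agree with the reported probabilities, roughly 0.996 for Negative and 0.004 for Positive.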
Here are the code changes:
def create_model(is_predicting, input_ids, input_mask, segment_ids, labels,
                 num_labels):
    """Creates a classification model."""
    bert_module = hub.Module(
        BERT_MODEL_HUB,
        trainable=True)
    bert_inputs = dict(
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids)
    bert_outputs = bert_module(
        inputs=bert_inputs,
        signature="tokens",
        as_dict=True)
    # Use "pooled_output" for classification tasks on an entire sentence.
    # Use "sequence_output" for token-level output.
    output_layer = bert_outputs["pooled_output"]
    # Keep a reference to the sentence embedding before dropout is applied.
    pooled_output = output_layer
    hidden_size = output_layer.shape[-1].value
    # Create our own layer to tune for politeness data.
    output_weights = tf.get_variable(
        "output_weights", [num_labels, hidden_size],
        initializer=tf.truncated_normal_initializer(stddev=0.02))
    output_bias = tf.get_variable(
        "output_bias", [num_labels], initializer=tf.zeros_initializer())
    with tf.variable_scope("loss"):
        # Dropout helps prevent overfitting
        output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)
        logits = tf.matmul(output_layer, output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)
        log_probs = tf.nn.log_softmax(logits, axis=-1)
        probs = tf.nn.softmax(logits, axis=-1)
        # Convert labels into one-hot encoding
        one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)
        predicted_labels = tf.squeeze(tf.argmax(log_probs, axis=-1, output_type=tf.int32))
        # If we're predicting, we want predicted labels and the probabilities.
        if is_predicting:
            return (predicted_labels, log_probs, probs, pooled_output)
        # If we're train/eval, compute loss between predicted and actual label
        per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
        loss = tf.reduce_mean(per_example_loss)
        return (loss, predicted_labels, log_probs, probs, pooled_output)
Next, add support for these values in model_fn_builder():
# this should be changed in both places
(predicted_labels, log_probs, probs, pooled_output) = create_model(
    is_predicting, input_ids, input_mask, segment_ids, label_ids, num_labels)
# return dictionary of all the values you wanted
predictions = {
    'log_probabilities': log_probs,
    'probabilities': probs,
    'labels': predicted_labels,
    'pooled_output': pooled_output
}
Adjust getPrediction() accordingly, and the predictions end up looking like this:
('That movie was absolutely awful',
 array([0.99599314, 0.00400678], dtype=float32), <= Probability
 array([-4.0148855e-03, -5.5197663e+00], dtype=float32), <= Log probability, same as previously
 'Negative', <= Label
 array([ 0.9181199 , 0.7763732 , 0.9999883 , -0.93533266, -0.9841384 ,
 0.78126144, -0.9918988 , -0.18764131, 0.9981035 , 0.99999994,
 0.900716 , -0.99926263, -0.5078789 , -0.99417543, -0.07695035,
 0.9501321 , 0.75836045, 0.49151263, -0.7886792 , 0.97505844,
 -0.8931161 , -1. , 0.9318583 , -0.60531116, -0.8644371 ,
 ...
and this is the 768-d [CLS] vector (sentence embedding).
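Once each prediction carries a pooled_output vector, one common way to use it is to compare two sentences by the cosine similarity of their embeddings. The sketch below is only an illustration: cosine_similarity is a hypothetical helper, and the short toy vectors stand in for the real 768-d embeddings.

```python
import math

def cosine_similarity(a, b):
    """Cosine similarity between two equal-length vectors (hypothetical helper)."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_a * norm_b)

# Toy stand-ins for two sentence embeddings (assumed values, not model output).
emb_a = [0.9, 0.7, -0.9]
emb_b = [0.6, 0.8, -0.8]
similarity = cosine_similarity(emb_a, emb_b)
```

With the adjusted getPrediction(), the embedding would be the last element of each returned tuple, so it could be passed to such a helper after converting the array to a list.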
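Regarding 2): my last training run took about 5 minutes, and testing took about 40 seconds. That is quite reasonable.
UPDATE
For 20k samples, training took 12:48 and testing took 2:07.
For 10k samples, the timings were 8:40 and 1:07 respectively.
And of course, the remaining changes are as follows: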
# model_fn_builder actually creates our model function
# using the passed parameters for num_labels, learning_rate, etc.
def model_fn_builder(num_labels, learning_rate, num_train_steps,
                     num_warmup_steps):
    """Returns `model_fn` closure for TPUEstimator."""
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""
        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        label_ids = features["label_ids"]
        is_predicting = (mode == tf.estimator.ModeKeys.PREDICT)
        # TRAIN and EVAL
        if not is_predicting:
            (loss, predicted_labels, log_probs, probs, pooled_output) = create_model(
                is_predicting, input_ids, input_mask, segment_ids, label_ids, num_labels)
            train_op = bert.optimization.create_optimizer(
                loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu=False)
            # Calculate evaluation metrics.
            def metric_fn(label_ids, predicted_labels):
                accuracy = tf.metrics.accuracy(label_ids, predicted_labels)
                f1_score = tf.contrib.metrics.f1_score(
                    label_ids,
                    predicted_labels)
                auc = tf.metrics.auc(
                    label_ids,
                    predicted_labels)
                recall = tf.metrics.recall(
                    label_ids,
                    predicted_labels)
                precision = tf.metrics.precision(
                    label_ids,
                    predicted_labels)
                true_pos = tf.metrics.true_positives(
                    label_ids,
                    predicted_labels)
                true_neg = tf.metrics.true_negatives(
                    label_ids,
                    predicted_labels)
                false_pos = tf.metrics.false_positives(
                    label_ids,
                    predicted_labels)
                false_neg = tf.metrics.false_negatives(
                    label_ids,
                    predicted_labels)
                return {
                    "eval_accuracy": accuracy,
                    "f1_score": f1_score,
                    "auc": auc,
                    "precision": precision,
                    "recall": recall,
                    "true_positives": true_pos,
                    "true_negatives": true_neg,
                    "false_positives": false_pos,
                    "false_negatives": false_neg
                }
            eval_metrics = metric_fn(label_ids, predicted_labels)
            if mode == tf.estimator.ModeKeys.TRAIN:
                return tf.estimator.EstimatorSpec(mode=mode,
                                                  loss=loss,
                                                  train_op=train_op)
            else:
                return tf.estimator.EstimatorSpec(mode=mode,
                                                  loss=loss,
                                                  eval_metric_ops=eval_metrics)
        else:
            (predicted_labels, log_probs, probs, pooled_output) = create_model(
                is_predicting, input_ids, input_mask, segment_ids, label_ids, num_labels)
            predictions = {
                'log_probabilities': log_probs,
                'probabilities': probs,
                'labels': predicted_labels,
                'pooled_output': pooled_output
            }
            return tf.estimator.EstimatorSpec(mode, predictions=predictions)
    # Return the actual model function in the closure
    return model_fn
def getPrediction(in_sentences):
    labels = ["Negative", "Positive"]
    # guid "" and label 0 are just dummy values
    input_examples = [run_classifier.InputExample(guid="", text_a = x, text_b = None, label = 0) for x in in_sentences]
    input_features = run_classifier.convert_examples_to_features(input_examples, label_list, MAX_SEQ_LENGTH, tokenizer)
    predict_input_fn = run_classifier.input_fn_builder(features=input_features, seq_length=MAX_SEQ_LENGTH, is_training=False, drop_remainder=False)
    predictions = estimator.predict(predict_input_fn)
    # Each tuple: sentence, probabilities, log probabilities, label name, [CLS] embedding
    return [(sentence, prediction['probabilities'], prediction['log_probabilities'], labels[prediction['labels']], prediction['pooled_output']) for sentence, prediction in zip(in_sentences, predictions)]
And here are the first outputs (the rest are cut off because of the 30K-character limit on answers):
[('That movie was absolutely awful',
array([0.99599314, 0.00400678], dtype=float32),
array([-4.0148855e-03, -5.5197663e+00], dtype=float32),
'Negative',
array([ 0.9181199 , 0.7763732 , 0.9999883 , -0.93533266, -0.9841384 ,
0.78126144, -0.9918988 , -0.18764131, 0.9981035 , 0.99999994,
0.900716 , -0.99926263, -0.5078789 , -0.99417543, -0.07695035,
0.9501321 , 0.75836045, 0.49151263, -0.7886792 , 0.97505844,
-0.8931161 , -1. , 0.9318583 , -0.60531116, -0.8644371 ,
-0.9999866 , 0.5820049 , 0.3257555 , -0.81900954, -0.8326617 ,
0.87788117, -0.7791749 , 0.11098853, 0.67873836, 0.9999771 ,
0.9833652 , -0.8420576 , 0.83076835, 0.37272754, 0.8667175 ,
0.792386 , -0.82003427, -0.9999999 , -0.9382297 , -0.9713775 ,
0.55752313, 1. , -0.72632766, -0.4752956 , -0.9999852 ,
-0.99974227, -0.9998661 , -0.3094257 , -0.93023825, -0.72663504,
0.92974335, -0.8601105 , -0.8113003 , 0.7660112 , 0.9313508 ,
0.21427669, -0.45660907, 0.99970686, 0.56852764, -0.9997675 ,
-0.9999096 , 0.8247045 , 0.7205424 , 0.47192624, -0.7523966 ,
-0.9588541 , -0.48866934, 0.9809366 , -0.07110611, -0.99886 ,
-0.63922834, -0.68144 , -1. , 0.8531816 , 0.26078308,
-0.99898577, -0.99968046, 0.6711601 , 0.99857473, -0.99990964,
1. , -0.97127694, -0.10644457, 0.46306637, -0.32486317,
-0.68167734, 0.43291137, -0.996574 , 0.05164305, 0.9897354 ,
0.93853104, 0.94800174, 0.9995697 , 0.6532897 , 0.93846226,
-0.6281378 , 0.5574107 , 0.725278 , 0.74160355, -0.6486919 ,
0.88869256, 0.9439776 , -0.9654787 , -0.95139974, -0.9366148 ,
0.17409436, 0.83473635, -0.87414986, -0.35965624, -0.8395183 ,
0.5546853 , 0.7452196 , -0.6152899 , -0.82187194, -0.65487677,
0.94367695, 0.6834396 , -0.72266734, 0.99376386, -0.76821744,
0.4485644 , 0.99982166, 1. , 0.9260674 , 0.9759094 ,
0.9397613 , 0.8128903 , -0.7918152 , 0.30299878, -0.95160294,
0.25385544, -0.57780135, -0.9999994 , 0.9168113 , -0.36585295,
0.9798102 , 0.95976156, -0.99428 , 0.6471789 , -0.9948078 ,
-0.9686591 , 0.93615085, -0.11481134, 0.87566274, -0.91601896,
0.9952683 , 0.26532048, 0.99861896, 0.79298306, 0.5872364 ,
-0.56314534, 0.96794534, 0.9999797 , 0.9879324 , 0.5003342 ,
0.9516269 , -0.8878316 , -0.9665091 , -0.88037425, 0.8356687 ,
-0.71543014, -0.99985015, -0.9414574 , 0.8681497 , 0.950698 ,
-0.8007153 , 0.78748596, 0.9999305 , 0.40210736, 0.4856055 ,
-0.9390776 , 0.63564163, -0.85989815, -0.8421344 , -0.99436 ,
0.78081733, -0.97038007, 0.39290914, 0.7834218 , 0.88715357,
-0.03653741, 0.99126273, -0.96559966, 0.11924513, -0.99363935,
-0.9901692 , 0.963858 , 0.5713922 , 0.5676979 , 0.69982123,
0.858003 , 0.9983819 , -0.87965024, 0.46213093, -0.3256273 ,
0.77337253, 0.7246244 , -0.99894017, -0.9170495 , -0.98803675,
-0.93148243, 0.09674019, 0.09448949, -0.7453027 , -0.78955775,
-0.6304773 , -0.5597632 , 0.992308 , 0.7769483 , 0.04146893,
-0.15876745, -0.7682887 , -0.5231416 , 0.7871302 , 0.9503481 ,
-0.9607153 , 0.99047405, -0.9948017 , -0.82257754, 0.9990552 ,
0.79346406, -0.78624016, 0.8760266 , -0.7855991 , 0.13444276,
-0.7183107 , -0.9999819 , 0.7019429 , -0.918913 , -0.6569654 ,
0.9998794 , -0.33805153, -0.9427715 , 0.10419375, -0.94257164,
0.9187495 , -0.9994855 , -0.99979955, -0.9277688 , 0.6353426 ,
0.9994905 , 0.90688777, 0.9992008 , 0.7817533 , -0.9996674 ,
-0.999962 , -0.13310781, -0.82505953, 0.9997485 , 0.82616794,
-0.999998 , 0.45386457, 0.6069964 , 0.52272975, 0.8811922 ,
0.52668494, -0.9994814 , -0.21601789, -0.99882716, 0.90246916,
0.94196504, 0.30058604, -0.9876776 , -0.7699927 , -0.9980288 ,
0.7727592 , 0.9936947 , 0.98021245, -0.77723926, -0.785372 ,
0.5150317 , 0.9983137 , -0.7461883 , 0.3311537 , -0.63709795,
-0.6487831 , -0.9173727 , 0.9997706 , -0.9999893 , -1. ,
0.60389155, -0.6516268 , -0.95422006, 1. , 0.09109057,
-0.99999994, 0.99998957, 1. , -0.19451752, 0.94624877,
-0.2761865 , 1. , 0.52399474, 0.70230734, 0.5218801 ,
-0.99716544, -0.70075685, -0.99992603, 1. , -0.9785006 ,
0.22457084, -0.5356722 , -0.9991887 , 0.7062409 , 0.66816545,
-0.90308225, -0.8084922 , 0.50301254, -0.7062079 , 0.9998321 ,
0.9823206 , 0.9984027 , 0.9948857 , -1. , -0.7067878 ,
0.975454 , 0.87161005, -0.9882297 , 0.8296374 , -0.88615334,
0.4316883 , 0.86287475, -0.9893329 , -0.9022001 , -0.68322754,
-0.84212875, 0.78632677, -0.5131366 , -0.996949 , -0.75479275,
-0.06342169, 0.92238575, 0.66769385, 0.9926053 , -0.78391105,
0.9976865 , 0.07086544, 0.34079495, 0.69730175, -0.99970955,
-1. , -0.9860551 , 0.89584446, -0.96889114, -0.90435815,
0.944296 , -1. , -0.9931756 , -0.7014334 , -0.6742562 ,
-0.96786517, 0.848328 , 0.8903087 , -0.9998633 , 0.73993397,
0.99345684, 0.9691821 , 0.87563246, -0.6073146 , -0.9999999 ,
0.90763575, 0.30225936, -0.47824544, 0.7179979 , 0.9450465 ,
0.9715953 , -0.5422173 , 0.99995065, -0.5920663 , 0.92390317,
-0.9670669 , -0.3623574 , 0.74825 , -0.7817521 , 0.9888685 ,
-0.7653631 , -0.8933355 , 0.9481424 , 0.97803396, -0.9999731 ,
-0.89597356, 0.35502487, -0.7190486 , 0.30777818, 0.55025375,
0.6365793 , -0.99094397, -1. , 0.93482614, -0.99970514,
0.98721176, 0.14699097, -0.86038756, -0.68365514, -0.8104672 ,
0.57238674, 0.97475344, -0.9963499 , 0.98476464, 0.40495875,
-0.7001948 , -0.40898973, 0.61900675, -1. , -0.9371812 ,
-0.62749994, -0.8841316 , -0.9999847 , -0.39386114, -0.925245 ,
-0.99991447, -0.5872595 , 0.5835767 , 0.7003338 , -0.9761974 ,
0.99995846, 0.33676207, 0.9079994 , -0.76412004, -0.7648706 ,
0.68863285, 0.43983305, 0.74911463, -0.99995685, -0.6692586 ,
-0.45761266, -0.9980771 , -1. , 0.31244457, -0.8834693 ,
0.9388263 , -0.987405 , 1. , 0.9512058 , 0.23448633,
0.37940192, 0.99989796, 0.8402514 , -0.84526414, 0.7378776 ,
-0.9996204 , -0.99434114, 0.9987527 , 0.5569713 , 0.99648696,
-0.9933159 , -0.13116199, 0.9999992 , 0.9642579 , -0.48285434,
-0.97517425, 0.7185596 , 0.5286405 , 0.9902838 , 0.7796022 ,
-0.80703837, 0.2376029 , 0.534117 , -0.9999413 , 0.99828076,
0.9998345 , 0.93249476, 0.3620626 , 0.7567034 , -0.9222681 ,
0.97832036, 0.9999682 , 0.6433209 , -1. , 0.9268615 ,
-0.9999511 , -0.9145363 , -0.9213852 , 0.7606066 , -0.5501025 ,
-0.99999434, -0.7783993 , 0.9999771 , 0.99980384, 0.987094 ,
0.7531475 , -0.8551696 , -0.9973968 , -0.9999853 , -0.08913276,
-0.9919206 , -0.49190572, 0.70230234, -0.31277484, -0.99999964,
0.828591 , 0.6363776 , 0.86796165, 0.81575817, 0.7782955 ,
0.9436437 , -1. , -0.7509046 , -0.9946139 , -0.6647415 ,
0.999543 , 0.9312092 , -1. , 0.5639159 , 0.9482462 ,
-0.9289936 , -0.9678435 , 0.60937124, -0.987818 , 0.5511619 ,
0.75886583, -0.48466644, -0.71833754, 0.8042149 , 0.9154103 ,
-0.8177468 , 0.7195895 , -0.82283056, 0.24990956, -1. ,
0.7729634 , 0.84048635, 0.7989596 , 0.9469012 , -0.9898951 ,
-0.92565274, 0.74726975, 0.78213847, -0.672894 , -0.58831286,
-0.8039038 , -0.72197783, 0.5289216 , -0.9998796 , -0.9904479 ,
0.9996592 , -0.28984115, 0.23964961, -0.7427149 , -0.662416 ,
-1. , -0.5538268 , -0.9945287 , -0.63471127, 0.5896127 ,
-0.48429146, 0.9976076 , -0.94329506, -0.49143887, 0.7695602 ,
0.8638134 , -0.82130384, 0.50105464, 0.9336961 , -0.24716294,
-0.6922282 , -0.02228704, 0.75649065, 0.82303154, -0.30867255,
-0.9602714 , 0.64568967, 0.314201 , -0.4811752 , 0.27952817,
0.9227022 , 0.88095886, 0.89470226, 1. , -0.19237158,
1. , -0.991253 , -0.9991121 , 0.5637482 , -0.75780976,
-0.3904836 , -0.9881965 , -0.2912058 , 0.9998215 , 0.9869475 ,
-0.12784953, 0.81566185, 0.9787118 , -0.17835459, -0.7027824 ,
0.72269535, -0.18194303, 0.9968796 , 0.03490257, 0.7751488 ,
-1. , -0.7761089 , 0.85105944, 0.9968074 , -0.8156342 ,
0.5300792 , -1. , 0.99626255, -0.7515625 , -0.6672005 ,
0.9792111 , 0.8660997 , -0.69161206, 0.32184905, 0.9071073 ,
0.9999385 , -0.82744277, -0.99044186, -0.71309817, -0.5004305 ,
0.70707524, 0.89751345, -0.6819585 , -0.9999414 , -0.45255637,
-0.94375473, -0.91838425, 0.64272994, 0.9375524 , 0.6609169 ,
-0.88743365, -0.9534722 , -0.47888806, -1. , -0.5251781 ,
0.8274516 , 0.9326824 , 0.8961964 , 0.5295862 , 0.43714878,
-0.7488347 , -0.75295556, -0.5187054 , 0.75924635, -0.7862662 ,
0.99981725, -0.80290836, 0.97651815, 0.99763787, -0.29619345,
-0.1252967 , 0.33606276, -0.65137684, -0.9680231 , 0.77586985,
0.22347753, 0.27245504, -0.07826214, -0.8383849 , -0.85373163,
1. , -0.4563588 , -0.91339815, -0.9999861 , 0.66063935,
-0.985843 , -0.7818757 , -0.7000497 , -0.6840764 , 0.9995542 ,
0.60819125, 0.80064404, -0.9776968 , -0.90925264, -0.6644932 ,
-0.8771755 , 0.71411085, 0.8113569 , 0.9974196 , -0.75211936,
0.63400257, -0.8272833 , 0.99780786, 0.9965285 , 0.59551436,
-0.9876875 , -0.04439292, 0.9939223 , 0.9993717 , -0.9965501 ,
-0.9630328 , -0.9027949 , -0.48490363, -0.60193753, -0.6870232 ,
-0.95355797, -0.67561924, 0.9997761 , -0.85473967, 0.998495 ,
-0.95756954, 0.633171 , 0.4570475 , -0.5316367 , -0.9663824 ,
0.9567106 , -0.45497724, 0.12964879, 0.9964744 , -0.9711668 ,
0.69636106, -0.9178346 , 0.8313186 , 0.69686604, 0.8141587 ,
-0.33600506, 0.94798595, 0.8800869 , 0.15029034, -0.91185665,
0.6322724 , -0.9971475 , 0.71948224, 0.9695236 , 0.84242374,
0.99995124, 0.5982563 , -0.98341423, 0.61301434, 0.9997318 ,
-0.9981808 , -0.65651804, -0.8484874 , -0.9961815 , 0.9030814 ,
0.87141925, 0.8897381 , -0.92870414, 0.07134341, 0.8739935 ,
0.91630197, -0.9465984 , -0.59741104, -1. , 0.9989559 ,
0.99991184, 0.67439264, 0.92025673, -0.60730827, 0.8362061 ,
1. , -0.70801497, 0.9883806 , -0.9984141 , 0.9919259 ,
-0.998869 , 0.9976203 , 0.9888036 , 0.8556838 , -0.9722744 ,
-0.99810714, 0.8182833 , 0.98808485, 0.6643728 , 0.99212515,
-0.99988 , 0.26405996, 0.93139845, 0.99021816, 0.6846886 ,
0.9986462 , 0.92254627, -0.6406982 ], dtype=float32)),
('The acting was a bit lacking',
array([0.9921152 , 0.00788479], dtype=float32),
array([-0.00791603, -4.842819 ], dtype=float32),
'Negative',
array([ 0.67417824, 0.8235167 , 0.99999565, -0.8565971 , -0.99499583,
0.8219966 , -0.9185583 , -0.5234593 , 0.99962074, 0.99999714,
0.9507927 , -0.9996754 , 0.22211392, -0.99826247, 0.7562492 ,
0.93803996, 0.82738185, 0.4773049 , -0.73478544, 0.85207295,