TypeError when using a custom loss function for a multi-label classification problem

I am trying to solve a multi-label text classification problem with BERT from the huggingface transformers library. The model is defined as follows:

def create_model(encoder, nb_classes=3, lr=1e-5):
    # inputs
    input_ids = tf.keras.Input(shape=(512,), ragged=False,
                               dtype=tf.int32, name='input_ids')
    input_attention_mask = tf.keras.Input(shape=(512,), ragged=False,
                                          dtype=tf.int32, name='attention_mask')
    # transformer
    output = encoder({'input_ids': input_ids,
                      'attention_mask': input_attention_mask})[0]
    Y = tf.keras.layers.BatchNormalization()(output)
    Y = tf.keras.layers.Dense(nb_classes, activation='sigmoid')(Y)
    # compilation
    model = tf.keras.Model(inputs=[input_ids, input_attention_mask],
                           outputs=[Y])
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
    # losses
    # loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
    # loss = tf.keras.losses.BinaryCrossentropy(from_logits=False)
    model.compile(optimizer=optimizer,
                  loss=multilabel_loss, metrics=['acc'])
    model.summary()
    return model

As you can see, I tried using tf.keras.losses, but it did not work (it raised AttributeError: 'Tensor' object has no attribute 'nested_row_splits'), so I defined a simple cross-entropy by hand:

def multilabel_loss(y_true, y_pred):
    y_pred = tf.convert_to_tensor(y_pred)
    y_true = tf.cast(y_true, y_pred.dtype)
    cross_entropy = -tf.reduce_sum((y_true*tf.math.log(y_pred + 1e-8) + (1 - y_true) * tf.math.log(1 - y_pred + 1e-8)),
                                   name='xentropy')
    return cross_entropy
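For checking values outside the graph, here is a NumPy sketch of the same per-element binary cross-entropy. One detail worth noticing: tf.reduce_sum without an axis adds up every element of the batch, so the loss grows with batch size and number of labels. The mean variant below averages over the label axis and then over the batch, which is roughly what tf.keras.losses.BinaryCrossentropy does with its default reduction. The function names are hypothetical and only for illustration.

```python
import numpy as np

EPS = 1e-8  # same epsilon as in multilabel_loss


def _bce_terms(y_true, y_pred):
    """Element-wise binary cross-entropy, shape (batch, nb_classes)."""
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    return -(y_true * np.log(y_pred + EPS)
             + (1 - y_true) * np.log(1 - y_pred + EPS))


def multilabel_bce_sum(y_true, y_pred):
    """Mirrors multilabel_loss: every element of the batch is summed."""
    return float(_bce_terms(y_true, y_pred).sum())


def multilabel_bce_mean(y_true, y_pred):
    """Averages over labels, then over the batch."""
    return float(_bce_terms(y_true, y_pred).mean(axis=-1).mean())


y_true = [[0, 0, 1], [0, 1, 0]]
y_pred = [[0.1, 0.2, 0.7], [0.3, 0.6, 0.2]]
print(multilabel_bce_sum(y_true, y_pred))   # scales with batch size
print(multilabel_bce_mean(y_true, y_pred))  # does not
# Repeating the batch twice doubles the summed loss but leaves the mean unchanged.
print(multilabel_bce_sum(y_true + y_true, y_pred + y_pred))
print(multilabel_bce_mean(y_true + y_true, y_pred + y_pred))
```

Either reduction trains, but the summed version effectively scales the learning rate with the batch size.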

The model is created inside strategy.scope() as shown below, with 'distil-bert-uncased' as the checkpoint:

with strategy.scope():
    encoder = TFAutoModelForSequenceClassification.from_pretrained(checkpoint)
    #encoder = TFRobertaForSequenceClassification.from_pretrained(checkpoint)
    model = create_model(encoder)

The labels are binary arrays:

163350    [0, 0, 1]
118940    [0, 0, 1]
65243     [0, 0, 1]
30011     [0, 0, 1]
189713    [0, 1, 0]

Together with the tokenized text, they are combined into a tf.data.Dataset by the following function:

def tf_text_data_prep(df):
    """
    input: takes pandas dataframe
    output: returns tokenized tf.Dataset
    """
    hugging_ds = Dataset.from_pandas(df)
    tokenized_ds = hugging_ds.map(
                      tokenize_function,
                      batched=True,
                      num_proc=strategy.num_replicas_in_sync,
                      remove_columns=["Text", '__index_level_0__'],
                      load_from_cache_file=True
                      )

    # Convert to tensorflow
    tf_dataset = tokenized_ds.with_format("tensorflow")
    features = {x: tf_dataset[x].to_tensor() for x in tokenizer.model_input_names}
    tf_data = tf.data.Dataset.from_tensor_slices((features, tf_dataset["label"]))
    return tf_data
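A related thing worth checking: the traceback in this post ends inside ragged_assert_compatible_and_get_flat_values, which hints that y_true and y_pred are not the kinds of tensors the metric expects. The features needed .to_tensor() because with_format("tensorflow") returned them as ragged tensors, and it is an assumption (not verified here) that tf_dataset["label"] may come back ragged in the same way while only the features are densified. If so, calling .to_tensor() on the label column too, or passing a dense array to from_tensor_slices, gives Keras a plain (num_examples, nb_classes) matrix. A minimal NumPy sketch of building such a matrix from a column of lists like the one shown above (labels_to_dense is a hypothetical helper):

```python
import numpy as np


def labels_to_dense(label_lists, nb_classes=3):
    """Stack per-example label lists into a dense (n, nb_classes) float32 matrix.

    Raises ValueError if any row has the wrong length, which would otherwise
    surface later as a ragged/shape mismatch inside Keras.
    """
    rows = [list(map(float, row)) for row in label_lists]
    bad = [i for i, row in enumerate(rows) if len(row) != nb_classes]
    if bad:
        raise ValueError(f"rows with wrong length: {bad}")
    return np.array(rows, dtype=np.float32).reshape(len(rows), nb_classes)


# The five labels shown in the question
labels = [[0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 1, 0]]
y = labels_to_dense(labels)
print(y.shape, y.dtype)  # (5, 3) float32
```

Casting to float32 up front also matches the sigmoid output dtype, so no cast is needed inside the loss or the metric.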

The problem is that when I start training, I get the following error:

TypeError                                 Traceback (most recent call last)
<ipython-input-62-720b4634d50e> in <module>()
----> 1 get_ipython().run_cell_magic('time', '', 'steps_per_epoch = int(BUFFER_SIZE // BATCH_SIZE)\nprint(\n    f"Model Params:\\nbatch_size: {BATCH_SIZE}\\nEpochs: {EPOCHS}\\n"\n    f"Step p. Epoch: {steps_per_epoch}\\n"\n    f"Initial Learning rate: {INITAL_LEARNING_RATE}"\n)\nhistory = model.fit(\n    train_ds,\n    validation_data=val_ds,\n    batch_size=BATCH_SIZE,\n    epochs=EPOCHS,\n    callbacks=callbacks,\n    verbose=1,\n)')

12 frames
/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py in run_cell_magic(self, magic_name, line, cell)
   2115             magic_arg_s = self.var_expand(line, stack_depth)
   2116             with self.builtin_trap:
-> 2117                 result = fn(magic_arg_s, cell)
   2118             return result
   2119 

<decorator-gen-53> in time(self, line, cell, local_ns)

/usr/local/lib/python3.7/dist-packages/IPython/core/magic.py in <lambda>(f, *a, **k)
    186     # but it's overkill for just that one bit of state.
    187     def magic_deco(arg):
--> 188         call = lambda f, *a, **k: f(*a, **k)
    189 
    190         if callable(arg):

/usr/local/lib/python3.7/dist-packages/IPython/core/magics/execution.py in time(self, line, cell, local_ns)
   1191         else:
   1192             st = clock2()
-> 1193             exec(code, glob, local_ns)
   1194             end = clock2()
   1195             out = None

<timed exec> in <module>()

/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
   1176                 _r=1):
   1177               callbacks.on_train_batch_begin(step)
-> 1178               tmp_logs = self.train_function(iterator)
   1179               if data_handler.should_sync:
   1180                 context.async_wait()

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
    887 
    888       with OptionalXlaContext(self._jit_compile):
--> 889         result = self._call(*args, **kwds)
    890 
    891       new_tracing_count = self.experimental_get_tracing_count()

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
    931       # This is the first call of __call__, so we have to initialize.
    932       initializers = []
--> 933       self._initialize(args, kwds, add_initializers_to=initializers)
    934     finally:
    935       # At this point we know that the initialization is complete (or less

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
    762     self._concrete_stateful_fn = (
    763         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
--> 764             *args, **kwds))
    765 
    766     def invalid_creator_scope(*unused_args, **unused_kwds):

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
   3048       args, kwargs = None, None
   3049     with self._lock:
-> 3050       graph_function, _ = self._maybe_define_function(args, kwargs)
   3051     return graph_function
   3052 

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   3442 
   3443           self._function_cache.missed.add(call_context_key)
-> 3444           graph_function = self._create_graph_function(args, kwargs)
   3445           self._function_cache.primary[cache_key] = graph_function
   3446 

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   3287             arg_names=arg_names,
   3288             override_flat_arg_shapes=override_flat_arg_shapes,
-> 3289             capture_by_value=self._capture_by_value),
   3290         self._function_attributes,
   3291         function_spec=self.function_spec,

/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
    997         _, original_func = tf_decorator.unwrap(python_func)
    998 
--> 999       func_outputs = python_func(*func_args, **func_kwargs)
   1000 
   1001       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    670         # the function a weak reference to itself to avoid a reference cycle.
    671         with OptionalXlaContext(compile_with_xla):
--> 672           out = weak_wrapped_fn().__wrapped__(*args, **kwds)
    673         return out
    674 

/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
    984           except Exception as e:  # pylint:disable=broad-except
    985             if hasattr(e, "ag_error_metadata"):
--> 986               raise e.ag_error_metadata.to_exception(e)
    987             else:
    988               raise

TypeError: in user code:

    /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/training.py:850 train_function  *
        return step_function(self, iterator)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/training.py:840 step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/distribute/distribute_lib.py:1285 run
        return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/distribute/distribute_lib.py:2833 call_for_each_replica
        return self._call_for_each_replica(fn, args, kwargs)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/distribute/distribute_lib.py:3608 _call_for_each_replica
        return fn(*args, **kwargs)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/training.py:833 run_step  **
        outputs = model.train_step(data)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/training.py:795 train_step
        self.compiled_metrics.update_state(y, y_pred, sample_weight)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/compile_utils.py:460 update_state
        metric_obj.update_state(y_t, y_p, sample_weight=mask)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/utils/metrics_utils.py:86 decorated
        update_op = update_state_fn(*args, **kwargs)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/metrics.py:177 update_state_fn
        return ag_update_state(*args, **kwargs)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/metrics.py:659 update_state  **
        [y_true, y_pred], sample_weight)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/utils/metrics_utils.py:546 ragged_assert_compatible_and_get_flat_values
        raise TypeError('One of the inputs does not have acceptable types.')

    TypeError: One of the inputs does not have acceptable types.

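This approach works for ordinary binary classification, but not for multi-label classification. Any help with this error, or with the approach in general, would be greatly appreciated.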


Answer:

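The problem is that you are using TFAutoModelForSequenceClassification. As the ForSequenceClassification part of the name says, it already includes a classification head: its summary shows that it ends in a Dense layer, so what it returns is classifier output rather than the encoder output you want.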

encoder = TFAutoModelForSequenceClassification.from_pretrained('bert-base-uncased')
encoder.summary()
'''
Result:
All model checkpoint layers were used when initializing TFBertForSequenceClassification.

Some layers of TFBertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Model: "tf_bert_for_sequence_classification_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
bert (TFBertMainLayer)       multiple                  109482240 
_________________________________________________________________
dropout_187 (Dropout)        multiple                  0         
_________________________________________________________________
classifier (Dense)           multiple                  1538      
=================================================================
Total params: 109,483,778
Trainable params: 109,483,778
Non-trainable params: 0
'''
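The parameter counts in that summary tell you what the head is. A Dense layer has in_features × out_features weights plus out_features biases, and with BERT-base's hidden size of 768 the classifier's 1538 parameters work out to exactly two outputs (768 × 2 + 2), i.e. a two-unit classification layer already sitting on top of the encoder. A quick check of the arithmetic (dense_params is just an illustrative helper):

```python
def dense_params(in_features: int, out_features: int) -> int:
    """Parameter count of a fully connected layer: weight matrix plus bias vector."""
    return in_features * out_features + out_features


HIDDEN = 768             # BERT-base hidden size
BERT_MAIN = 109_482_240  # "bert (TFBertMainLayer)" row of the summary

classifier = dense_params(HIDDEN, 2)
print(classifier)              # 1538
print(BERT_MAIN + classifier)  # 109483778, the "Total params" line
```

So in the original create_model, the BatchNormalization and Dense(3) layers were stacked on those two classifier outputs rather than on BERT's 768-dimensional representation.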

But you want to use it as an encoder, so you need to do this instead:

from transformers import TFBertModel

encoder = TFBertModel.from_pretrained('bert-base-uncased')
model = create_model(encoder)
encoder.summary()
'''
Model: "tf_bert_model_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
bert (TFBertMainLayer)       multiple                  109482240 
=================================================================
Total params: 109,482,240
Trainable params: 109,482,240
Non-trainable params: 0
'''

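As you can see, the encoder is now the bare BERT model, so the layers that create_model stacks on top of it make sense. It will still fail, however, because of this line in create_model: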

output = encoder({'input_ids': input_ids,
                  'attention_mask': input_attention_mask})[0]

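This is because the output at index 0 (last_hidden_state) has shape (batch_size, token_length, embedding), i.e. one vector per token, whereas we want a single [CLS]-based vector of shape (batch_size, embedding). That is the output at index 1 (pooler_output, BERT's [CLS] hidden state passed through a dense layer with tanh activation), so the line should become: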

output = encoder({'input_ids': input_ids,
                  'attention_mask': input_attention_mask})[1]
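To make the shapes concrete, here is a NumPy sketch with random stand-in tensors; the sizes are toy values (real BERT-base uses a hidden size of 768) and the weights are random placeholders, not BERT's pooler or classifier weights. It shows why index 0 fails: a Dense layer applied to a 3-D tensor acts only on the last axis, so the predictions come out as (batch, seq_len, 3) and no longer line up with labels of shape (batch, 3). If you later switch to a checkpoint whose base model has no pooler (TFDistilBertModel, for instance, only returns last_hidden_state), slicing [:, 0, :] from the index-0 output gives the raw [CLS] vector instead.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, seq_len, hidden, nb_classes = 2, 5, 8, 3  # toy sizes

# Stand-in for encoder(...)[0], i.e. last_hidden_state
last_hidden_state = rng.normal(size=(batch, seq_len, hidden))

# A Dense layer applied to a 3-D input acts on the last axis only,
# so the token axis survives.
w_cls, b_cls = rng.normal(size=(hidden, nb_classes)), np.zeros(nb_classes)
per_token_logits = last_hidden_state @ w_cls + b_cls
print(per_token_logits.shape)  # (2, 5, 3): one prediction per token

# [CLS] is the first token; BERT's pooler maps it through dense + tanh.
cls_vector = last_hidden_state[:, 0, :]              # (batch, hidden)
w_pool, b_pool = rng.normal(size=(hidden, hidden)), np.zeros(hidden)
pooled = np.tanh(cls_vector @ w_pool + b_pool)       # like pooler_output
logits = pooled @ w_cls + b_cls
print(cls_vector.shape, pooled.shape, logits.shape)  # (2, 8) (2, 8) (2, 3)
```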

Also, the input shape is currently fixed at 512; it should be specified as None so that inputs of variable length can be handled:

input_ids = tf.keras.Input(shape=(None,), ragged=False, dtype=tf.int32, name='input_ids')
input_attention_mask = tf.keras.Input(shape=(None,), ragged=False, dtype=tf.int32, name='attention_mask')
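With shape=(None,), each batch may have its own sequence length, as long as the examples within a batch are padded to a common length and the attention mask marks the padding. The tokenizer normally does this for you (for example with padding=True, which pads to the longest sequence in the batch); the NumPy sketch below only illustrates what the resulting input_ids / attention_mask pair looks like. pad_batch is a hypothetical helper, and the token ids are made up.

```python
import numpy as np


def pad_batch(sequences, pad_id=0):
    """Pad variable-length id lists to the longest one in this batch.

    Returns (input_ids, attention_mask), both int32 of shape (batch, max_len);
    the mask is 1 for real tokens and 0 for padding.
    """
    max_len = max(len(s) for s in sequences)
    input_ids = np.full((len(sequences), max_len), pad_id, dtype=np.int32)
    attention_mask = np.zeros((len(sequences), max_len), dtype=np.int32)
    for i, seq in enumerate(sequences):
        input_ids[i, :len(seq)] = seq
        attention_mask[i, :len(seq)] = 1
    return input_ids, attention_mask


# Made-up token ids; a real tokenizer would produce these.
short_batch = [[11, 12, 13], [11, 14]]
long_batch = [[11, 12, 13, 15, 16, 17, 18]]

# Shapes differ between batches: (2, 3) then (1, 7).
for batch in (short_batch, long_batch):
    ids, mask = pad_batch(batch)
    print(ids.shape, mask.tolist())
```

Padding per batch instead of always to 512 also saves a lot of compute when most texts are short.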

After making all of these changes, here is the result of a sample run:

from transformers import BertTokenizerFast, TFBertModel

tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
encoder = TFBertModel.from_pretrained('bert-base-uncased')
model = create_model(encoder)
inputs = tokenizer('hello world', return_tensors='tf')
model.predict((inputs['input_ids'], inputs['attention_mask']))
'''
Results:
array([[0.7867866 , 0.65974414, 0.45628983]], dtype=float32)
'''
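Because the output layer uses sigmoid, each of the three values is an independent probability, so they need not sum to 1 (and here they do not). Turning them into label decisions is usually done with a per-label threshold; the 0.5 below is an assumption, not something from the original post. For monitoring, an element-wise metric such as tf.keras.metrics.BinaryAccuracy (threshold 0.5 by default) is generally a better fit for multi-label outputs than the string 'acc', which Keras resolves from the output shape. A minimal NumPy sketch of both steps, using the numbers from the run above and a hypothetical ground truth:

```python
import numpy as np

# Sigmoid outputs from the sample run above
probs = np.array([[0.7867866, 0.65974414, 0.45628983]], dtype=np.float32)

THRESHOLD = 0.5  # assumed; can be tuned per label on validation data
predicted = (probs > THRESHOLD).astype(np.int32)
print(predicted.tolist())  # [[1, 1, 0]]
print(float(probs.sum()))  # not 1: labels are scored independently


def binary_accuracy(y_true, probs, threshold=THRESHOLD):
    """Fraction of individual label decisions that match y_true."""
    y_true = np.asarray(y_true)
    return float(np.mean(y_true == (probs > threshold)))


# Hypothetical ground truth, for illustration only
print(binary_accuracy([[1, 0, 0]], probs))  # 2 of 3 label decisions match
```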
