Using custom preprocessing functionsΒΆ

Preprocessing functions can be used with ml4ir in the data loading pipeline. Below we demonstrate how to define a custom preprocessing function and use it to load the data to train a RelevanceModel.

In this example, we define a preprocessing function to split a string into tokens and pad to max length.

@tf.function
def split_and_pad_string(feature_tensor, split_char=",", max_length=20):
    tokens = tf.strings.split(feature_tensor, sep=split_char).to_tensor()
    
    padded_tokens = tf.image.pad_to_bounding_box(
        tf.expand_dims(tokens[:, :max_length], axis=-1),
        offset_height=0,
        offset_width=0,
        target_height=1,
        target_width=max_length,
    )    
    padded_tokens = tf.squeeze(padded_tokens, axis=-1)
    
    return padded_tokens

Define the preprocessing function in the FeatureConfig YAML:

- name: query_text
  node_name: query_text
  trainable: true
  dtype: string
  log_at_inference: true
  preprocessing_info:
    - fn: split_and_pad_string
      args:
        split_char: " "
        max_length: 20
  serving_info:
    name: query_text
    required: true

Finally, use the custom split and pad prepreprocessing function to load a RelevanceDataset by passing custom functions as the preprocessing_keys_to_fns argument:

custom_preprocessing_fns = {
    "split_and_pad_string": split_and_pad_string
}

relevance_dataset = RelevanceDataset(
        data_dir=CSV_DATA_DIR,
        data_format=DataFormatKey.CSV,
        feature_config=feature_config,
        tfrecord_type=TFRecordTypeKey.EXAMPLE,
        batch_size=128,
        preprocessing_keys_to_fns=custom_preprocessing_fns,
        file_io=file_io,
        logger=logger
    )

Optionally, we can save preprocessing functions in the SavedModel as part of the serving signature as well. This requires that the preprocessing function is a tf.function that can be serialized as a tensorflow layer.

relevance_model.save(
    models_dir=MODEL_DIR,
    preprocessing_keys_to_fns=custom_preprocessing_fns,
    required_fields_only=True)