
Transformer Models For Custom Text Classification Through Fine-Tuning | by Skanda Vivek | Jan, 2023


Fine-Tuned SMS Spam Classifier Model Output | Skanda Vivek

The DistilBERT model was introduced by the folks at Hugging Face as a cheaper, faster alternative to large transformer models like BERT, and was first presented in a blog post. The way this model works is by using a teacher-student training approach, where the “student” model is a smaller version of the teacher model. Then, instead of training the student on the ultimate target outputs (basically one-hot encodings of the label class), it is trained on the softmax outputs of the original “teacher” model. It is a brilliantly simple idea, and the authors show that:

“it is possible to reduce the size of a BERT model by 40%, while retaining 97% of its language understanding capabilities and being 60% faster.”
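To make the distillation idea concrete, here is a minimal sketch of that loss in PyTorch. This is illustrative only: the temperature T and weight alpha are assumptions on my part, and the actual DistilBERT training objective also includes a masked language modeling loss and a cosine embedding loss on the hidden states.

import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=2.0, alpha=0.5):
    # Soft targets: match the teacher's temperature-softened softmax outputs
    # via KL divergence (scaled by T^2, following Hinton et al.)
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    ) * (T * T)
    # Hard targets: ordinary cross-entropy against the one-hot labels
    hard = F.cross_entropy(student_logits, labels)
    return alpha * soft + (1 - alpha) * hard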

In this example, I use the SMS Spam Collection dataset from the UCI Machine Learning Repository and build a classifier that detects SPAM vs. HAM (not SPAM). The data contains 5,574 rows of SMS texts that are labeled as SPAM or HAM.

First, I make train and validation data from the original csv and use the load_dataset function from the Hugging Face datasets library.

from datasets import load_dataset
import pandas as pd

# Read the raw csv and keep only the label and text columns
df = pd.read_csv('/content/spam.csv', encoding="ISO-8859-1")
df = df[['v1', 'v2']]
df.columns = ['label', 'text']

# Encode the labels: ham -> 0, spam -> 1
df.loc[df['label'] == 'ham', 'label'] = 0
df.loc[df['label'] == 'spam', 'label'] = 1

# Split into train and test csvs (first 4,179 rows for training)
df[:4179].reset_index(drop=True).to_csv('df_train.csv', index=False)
df[4179:].reset_index(drop=True).to_csv('df_test.csv', index=False)

dataset = load_dataset('csv', data_files={'train': '/content/df_train.csv',
                                          'test': '/content/df_test.csv'})

The next step is to load the DistilBERT tokenizer to preprocess the text data.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

def preprocess_function(examples):
    return tokenizer(examples["text"], truncation=True, padding=True)

tokenized_data = dataset.map(preprocess_function, batched=True)

from transformers import DataCollatorWithPadding

# Dynamically pad each batch to the length of its longest sequence
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

Prior to training, you need to map IDs to labels. After this, you specify the training hyperparameters, call trainer.train() to begin fine-tuning, and push the trained model to the Hugging Face Hub using trainer.push_to_hub().

id2label = {0: "HAM", 1: "SPAM"}
label2id = {"HAM": 0, "SPAM": 1}

from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=2, id2label=id2label, label2id=label2id
)

training_args = TrainingArguments(
    output_dir="spam-classifier",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=5,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    push_to_hub=True,
)
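Note that the Trainer below expects a compute_metrics function, which the original snippet doesn't define. A minimal version, assuming the Hugging Face evaluate library is installed, could look like this:

import numpy as np
import evaluate

accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    # The Trainer passes (logits, labels); report plain accuracy
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return accuracy.compute(predictions=predictions, references=labels)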

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_data["train"],
    eval_dataset=tokenized_data["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

trainer.push_to_hub()

That’s it! As you can see from the Hugging Face Hub, the model accuracy is pretty good (0.9885)!
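If you’d rather reproduce that number locally than read it off the Hub, trainer.evaluate() runs the eval_dataset through the compute_metrics function defined above:

# Evaluate on the held-out test split (already set as eval_dataset)
metrics = trainer.evaluate()
print(metrics)  # includes 'eval_accuracy'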

Inference is also relatively straightforward. You can see the output by running Python scripts as below:

text = "Email Alert From: Ash Kopatz. Click here to get a free prescription refill!"

from transformers import pipeline

# Load the fine-tuned model from the Hub and classify the text
classifier = pipeline("sentiment-analysis", model="skandavivek2/spam-classifier")
classifier(text)
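The pipeline returns the predicted label together with a confidence score, as a list like [{'label': 'SPAM', 'score': ...}] (the exact score will depend on your trained model).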

Sample Fine-Tuned SMS Spam Classifier Model Output | Skanda Vivek

Or run it on the Hugging Face Hub:

And that’s it! Hugging Face makes it very easy and accessible to adapt cutting-edge transformer models to custom language tasks, as long as you have the data!

Here is the GitHub link to the code:

If you liked this blog, check out my other blog on fine-tuning Transformers for Question Answering!

References:

  1. https://www.kaggle.com/datasets/uciml/sms-spam-collection-dataset
  2. Dua, D. and Graff, C. (2019). UCI Machine Learning Repository [http://archive.ics.uci.edu/ml]. Irvine, CA: University of California, School of Information and Computer Science.
  3. Almeida, T.A., Gómez Hidalgo, J.M., Yamakami, A. Contributions to the Study of SMS Spam Filtering: New Collection and Results. Proceedings of the 2011 ACM Symposium on Document Engineering (DOCENG’11), Mountain View, CA, USA, 2011.
  4. https://huggingface.co/docs/transformers/training