Commit 18c7f45b authored by kulcsar's avatar kulcsar

further tmix changes in eval, train and main

parent 330101c5
+17 −16
@@ -27,22 +27,23 @@ def train(model, name,train_dataset, test_dataset, seed, batch_size, test_batch_
	"""Train loop for models. Iterates over epochs and batches and gives inputs to model. After training, call evaluation.py for evaluation of finetuned model.
	
	Params:
-	model: model out of models.py
-	name: str
-	train_dataset: Dataset
-	test_dataset: Dataset
-	seed: int
-	batch_size:
-	test_batch_size:
-	num_epochs: int
-	imdb: bool
-	mixup: bool
-	lambda_value: float
-	mixepoch: int
-	tmix: bool
-	mixlayer: int in {0, 11}
-	learning_rate: float
-	mlp_leaning_rate: float
+
+	model: model out of models.py -> WordClassificationModel, BertForWordClassification or RobertaForWordClassification
+	name: str -> specifies the architecture of the model (either bert-base-uncased or roberta-base)
+	train_dataset: Dataset -> train dataset as torch Dataset object (created in preprocess.py)
+	test_dataset: Dataset -> test dataset as torch Dataset object (created in preprocess.py)
+	seed: int -> random seed
+	batch_size: int -> batch size for training
+	test_batch_size: int -> batch size for testing
+	num_epochs: int -> number of epochs
+	imdb: bool -> whether or not the imdb dataset is used
+	mixup: bool -> whether or not to use mixup in training
+	lambda_value: float -> if mixup or tmix is selected, the lambda value to use
+	mixepoch: int -> specifies in which epoch to use mixup
+	tmix: bool -> whether or not tmix is used in training (used to differentiate between mixing in training and not mixing in evaluation)
+	mixlayer: int in {0, ..., 11} -> which layer to mix in for tmix
+	learning_rate: float -> learning rate for the Bert/Roberta model, or for WordClassificationModel including the linear classifier
+	mlp_leaning_rate: float -> separate learning rate for the multilayer perceptron
	
	
	Returns: Evaluation Results for train and test dataset in Accuracy, F1, Precision and Recall"""
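The mixup interpolation controlled by `lambda_value` can be sketched roughly as follows. This is a minimal numpy sketch of standard mixup, not the repo's implementation: the function name `mixup_batch` and the one-hot label representation are assumptions on my part.

```python
import numpy as np

def mixup_batch(x, y, lambda_value=0.4, rng=None):
    """Hypothetical sketch of a mixup step: interpolate a batch with a
    shuffled copy of itself using a fixed lambda_value.

    x: float array of shape (batch, ...), e.g. input embeddings
    y: float array of one-hot labels, shape (batch, num_classes)
    """
    rng = np.random.default_rng() if rng is None else rng
    # Pair each example with a randomly chosen partner from the same batch.
    perm = rng.permutation(len(x))
    # Convex combination of inputs and of labels with the same lambda.
    x_mix = lambda_value * x + (1 - lambda_value) * x[perm]
    y_mix = lambda_value * y + (1 - lambda_value) * y[perm]
    return x_mix, y_mix
```

For tmix the same interpolation would be applied to hidden states at the chosen `mixlayer` rather than to the raw inputs; the principle (convex combination of two examples and their labels) is identical.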
+1 −2
@@ -208,8 +208,7 @@ if __name__ == "__main__":
		"-lambda",
		"--lambda_value",
		help="specifies the lambda value for mixup",
-		type=float,
-		default=0.4)
+		type=float)

	parser.add_argument(
		"-mixepoch",