1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51
import itertools

import trax
from trax import layers as tl
from trax.supervised import training
def train_model(model, data_generator, batch_size=32, max_length=64,
                lines=lines, eval_lines=eval_lines, n_steps=1,
                output_dir='model/'):
    """Function that trains the model.

    Args:
        model (trax.layers.combinators.Serial): GRU model.
        data_generator (function): Data generator function.
        batch_size (int, optional): Number of lines per batch. Defaults to 32.
        max_length (int, optional): Maximum length allowed for a line to be
            processed. Defaults to 64.
        lines (list, optional): List of lines to use for training. Defaults to
            the module-level `lines` (NOTE: captured once at definition time).
        eval_lines (list, optional): List of lines to use for evaluation.
            Defaults to the module-level `eval_lines` (NOTE: captured once at
            definition time).
        n_steps (int, optional): Number of steps to train. Defaults to 1.
        output_dir (str, optional): Relative path of directory to save model.
            Defaults to "model/".

    Returns:
        trax.supervised.training.Loop: Training loop for the model.
    """
    # Wrap the finite generators with itertools.cycle so the training loop
    # never exhausts its data stream, however many steps are requested.
    bare_train_generator = data_generator(batch_size, max_length, data_lines=lines)
    infinite_train_generator = itertools.cycle(bare_train_generator)

    bare_eval_generator = data_generator(batch_size, max_length, data_lines=eval_lines)
    infinite_eval_generator = itertools.cycle(bare_eval_generator)

    train_task = training.TrainTask(
        labeled_data=infinite_train_generator,
        loss_layer=tl.CrossEntropyLoss(),
        optimizer=trax.optimizers.Adam(0.0005),
    )

    eval_task = training.EvalTask(
        labeled_data=infinite_eval_generator,
        metrics=[tl.CrossEntropyLoss(), tl.Accuracy()],
        n_eval_batches=3,  # average eval metrics over 3 batches
    )

    # FIX: training.Loop documents `eval_tasks` as a list of EvalTask
    # objects; the original passed a bare EvalTask, which relies on
    # undocumented leniency in some trax versions.
    training_loop = training.Loop(model,
                                  train_task,
                                  eval_tasks=[eval_task],
                                  output_dir=output_dir)

    training_loop.run(n_steps=n_steps)

    return training_loop
|