"""Train and evaluate models."""
import logging

import data_pipeline
from model import MovierecModel

DEFAULT_PARAMS = {
    # model:
    # "num_users": obtained from data generator,
    # "num_items": obtained from data generator,
    "layers_sizes": [64, 32, 16, 8],
    "layers_l2reg": [0, 0, 0, 0],

    # training:
    "optimizer": "adam",
    "lr": 0.001,
    "beta_1": 0.9,
    "beta_2": 0.999,
    "batch_size": 100,
    "batch_size_eval": 200,
    "num_negs_per_pos": 9,
    "num_negs_per_pos_eval": 99,
    "k": 5,
    "epochs": 20,
}
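
# Example (illustrative sketch, kept as a comment so it does not run on import):
# individual hyperparameters can be overridden by passing a modified copy of
# DEFAULT_PARAMS to train(), defined below. "my-model" and the dataset name are
# placeholders, not values from this repo.
#
#   custom_params = dict(DEFAULT_PARAMS, epochs=5, lr=0.0005)
#   train("my-model", "<movielens-dataset-name>", "data/", "models", params=custom_params)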


def train(model_name, dataset_name, data_dir, output_dir, params=DEFAULT_PARAMS, verbose=1):
    """
    Create the train and validation data generators, create and compile the model, and train it.

    Parameters
    ----------
    model_name : str
        Name of the model to train (used in the model object and in the saved model files).
    dataset_name : str
        MovieLens dataset name. Must be one of MOVIELENS_DATASET_NAMES.
    data_dir : str or os.path
        Dataset directory to read ratings data from. The file read from this directory is:
        data_dir/dataset_name/data_pipeline.RATINGS_FILE_NAME[dataset_name].
    output_dir : str or os.path
        Output directory to save model files to.
    params : dict of param names (str) to values (any type)
        Dictionary of model hyperparameters. Default: `DEFAULT_PARAMS`.
    verbose : int
        Verbosity mode.
    """
    # TODO: cache data?
    train_df, validation_df, test_df = data_pipeline.load_ratings_train_test_sets(dataset_name, data_dir)

    # Train and validation data generators.
    train_data_generator = data_pipeline.MovieLensDataGenerator(
        dataset_name,
        train_df,
        params["batch_size"],
        params["num_negs_per_pos"],
        extra_data_df=None,  # Don't use any validation/test info:
                             # some negatives in train might be positives in val/test.
        shuffle=True)
    validation_data_generator = data_pipeline.MovieLensDataGenerator(
        dataset_name,
        validation_df,
        params["batch_size_eval"],
        params["num_negs_per_pos_eval"],
        extra_data_df=train_df,  # Use train to avoid positives from train in validation.
        shuffle=False)

    # Update users and items (obtained from data) and create model.
    params["num_users"] = train_data_generator.num_users
    params["num_items"] = train_data_generator.num_items
    movierec_model = MovierecModel(params, model_name, output_dir, verbose)
    movierec_model.log_summary()

    movierec_model.fit_generator(train_data_generator, validation_data_generator, params["epochs"])
    movierec_model.save()


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Train a Keras model.')
    parser.add_argument('-m', '--model-name', type=str, required=True,
                        help='Model name (to save output files).')
    parser.add_argument('-n', '--dataset-name', type=str, required=True,
                        help='MovieLens dataset name.')
    parser.add_argument('-d', '--data-dir', type=str, default='data/',
                        help='Dataset directory to read ratings data from.')
    parser.add_argument('-o', '--output-dir', type=str, default='models',
                        help='Output dir to save model files.')
    parser.add_argument('-l', '--log-level', type=str, default='INFO',
                        help='Log level (default: INFO).')
    # TODO Allow `params` as argument
    args = parser.parse_args()

    logging.basicConfig(level=logging.getLevelName(args.log_level))
    logging.info("Starting training with params: {}".format(DEFAULT_PARAMS))
    train(args.model_name, args.dataset_name, args.data_dir, args.output_dir, DEFAULT_PARAMS,
          verbose=logging.getLogger().level)
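
# Example command-line invocation (illustrative sketch; the model and dataset names
# below are placeholders, the flags are the argparse options defined above):
#
#   python trainer.py --model-name my-model --dataset-name <movielens-dataset-name> \
#       --data-dir data/ --output-dir models --log-level INFO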