package org.grouplens.lenskit.eval.data.crossfold;

import com.google.common.io.Files;
import it.unimi.dsi.fastutil.longs.Long2ObjectMap;
import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap;
import it.unimi.dsi.fastutil.longs.LongList;
import it.unimi.dsi.fastutil.longs.LongListIterator;
import java.io.File;
import java.io.IOException;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Statement;
import java.sql.Types;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import org.grouplens.lenskit.cursors.Cursor;
import org.grouplens.lenskit.data.dao.DAOFactory;
import org.grouplens.lenskit.data.dao.DataAccessObject;
import org.grouplens.lenskit.data.event.Rating;
import org.grouplens.lenskit.data.pref.Preference;
import org.grouplens.lenskit.data.sql.BasicSQLStatementFactory;
import org.grouplens.lenskit.data.sql.JDBCRatingDAO;
import org.grouplens.lenskit.data.sql.JDBCUtils;
import org.grouplens.lenskit.eval.PreparationContext;
import org.grouplens.lenskit.eval.PreparationException;
import org.grouplens.lenskit.eval.data.traintest.GenericTTDataSet;
import org.grouplens.lenskit.eval.data.traintest.TTDataSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.sqlite.SQLiteConfig;

/* loaded from: input_file:org/grouplens/lenskit/eval/data/crossfold/DBCrossfoldTTDataSet.class */
public class DBCrossfoldTTDataSet implements TTDataSet {
    private CrossfoldManager manager;
    private final int foldNumber;
    private final String stampName;
    private final String dbName;
    private File dbFile;
    private GenericTTDataSet dataset;
    private Logger logger = LoggerFactory.getLogger(DBCrossfoldTTDataSet.class);
    private boolean useTimestamp = true;

    public DBCrossfoldTTDataSet(CrossfoldManager crossfoldManager, int i) {
        this.manager = crossfoldManager;
        this.foldNumber = i;
        this.stampName = String.format("data.%d.stamp", Integer.valueOf(i));
        this.dbName = String.format("data.%d.db", Integer.valueOf(i));
    }

    protected String getDSN() {
        return "jdbc:sqlite:" + this.dbFile.getAbsolutePath();
    }

    @Override // org.grouplens.lenskit.eval.Preparable
    public long lastUpdated(PreparationContext preparationContext) {
        File file = new File(this.manager.cacheDir(preparationContext), this.stampName);
        if (file.exists()) {
            return file.lastModified();
        }
        return -1L;
    }

    @Override // org.grouplens.lenskit.eval.Preparable
    public void prepare(PreparationContext preparationContext) throws PreparationException {
        preparationContext.prepare(this.manager);
        try {
            Class.forName("org.sqlite.JDBC");
            this.dbFile = new File(this.manager.cacheDir(preparationContext), this.dbName);
            if (preparationContext.isUnconditional() || lastUpdated(preparationContext) < this.manager.lastUpdated(preparationContext)) {
                importDataSet(preparationContext);
            } else {
                this.logger.debug("Data set {} up to date", this);
            }
            this.dataset = new GenericTTDataSet(getName(), makeDAOFactory("train"), makeDAOFactory("test"));
            preparationContext.prepare(this.dataset);
        } catch (ClassNotFoundException e) {
            throw new PreparationException("Cannot load JDBC driver", e);
        }
    }

    private DAOFactory makeDAOFactory(String str) {
        BasicSQLStatementFactory basicSQLStatementFactory = new BasicSQLStatementFactory();
        basicSQLStatementFactory.setTableName(str);
        basicSQLStatementFactory.setTimestampColumn(this.useTimestamp ? "timestamp" : null);
        SQLiteConfig sQLiteConfig = new SQLiteConfig();
        sQLiteConfig.setReadOnly(true);
        return new JDBCRatingDAO.Factory(getDSN(), basicSQLStatementFactory, sQLiteConfig.toProperties());
    }

    void importDataSet(PreparationContext preparationContext) throws PreparationException {
        this.logger.debug("Importing fold {} for {}", Integer.valueOf(this.foldNumber), this.manager.getSource().getName());
        this.dbFile.delete();
        DataAccessObject create = this.manager.getSource().getDAOFactory().create();
        try {
            try {
                Connection connection = DriverManager.getConnection(getDSN());
                try {
                    writePartitionData(preparationContext, create, connection);
                    connection.close();
                    try {
                        Files.touch(new File(this.manager.cacheDir(preparationContext), this.stampName));
                    } catch (IOException e) {
                        throw new PreparationException("Failed to update stamp", e);
                    }
                } catch (Throwable th) {
                    connection.close();
                    throw th;
                }
            } finally {
                create.close();
            }
        } catch (SQLException e2) {
            throw new PreparationException("Failed to write database", e2);
        }
    }

    private String makeInsertSQL(String str) {
        StringBuilder sb = new StringBuilder();
        sb.append("INSERT INTO ");
        sb.append(str);
        sb.append(" (id, user, item, rating");
        if (this.useTimestamp) {
            sb.append(", timestamp");
        }
        sb.append(") VALUES (?, ?, ?, ?");
        if (this.useTimestamp) {
            sb.append(", ?");
        }
        sb.append(")");
        this.logger.debug("Insert SQL: {}", sb);
        return sb.toString();
    }

    private String makeCreateSQL(String str) {
        StringBuilder sb = new StringBuilder();
        sb.append("CREATE TABLE ");
        sb.append(str);
        sb.append(" (id INTEGER PRIMARY KEY, user INTEGER NOT NULl, item INTEGER NOT NULL, rating REAL");
        if (this.useTimestamp) {
            sb.append(", timestamp INTEGER NULL");
        }
        sb.append(")");
        return sb.toString();
    }

    private Long2ObjectMap<List<Rating>> initUserMap(PreparationContext preparationContext) {
        LongList foldUsers = this.manager.getFoldUsers(preparationContext, this.foldNumber);
        Long2ObjectOpenHashMap long2ObjectOpenHashMap = new Long2ObjectOpenHashMap(foldUsers.size());
        LongListIterator it = foldUsers.iterator();
        while (it.hasNext()) {
            long2ObjectOpenHashMap.put(it.nextLong(), new ArrayList());
        }
        return long2ObjectOpenHashMap;
    }

    private void bindRating(PreparedStatement preparedStatement, Rating rating) throws SQLException {
        preparedStatement.setLong(1, rating.getId());
        preparedStatement.setLong(2, rating.getUserId());
        preparedStatement.setLong(3, rating.getItemId());
        Preference preference = rating.getPreference();
        if (preference == null) {
            preparedStatement.setNull(4, 7);
        } else {
            preparedStatement.setDouble(4, preference.getValue());
        }
        if (this.useTimestamp) {
            long timestamp = rating.getTimestamp();
            if (timestamp >= 0) {
                preparedStatement.setLong(5, timestamp);
            } else {
                preparedStatement.setNull(5, 4);
            }
        }
    }

    /* JADX WARN: Finally extract failed */
    private void writePartitionData(PreparationContext preparationContext, DataAccessObject dataAccessObject, Connection connection) throws SQLException {
        Holdout holdout = this.manager.getHoldout();
        PreparedStatement preparedStatement = null;
        PreparedStatement preparedStatement2 = null;
        connection.setAutoCommit(false);
        try {
            JDBCUtils.execute(connection, makeCreateSQL("train"));
            JDBCUtils.execute(connection, makeCreateSQL("test"));
            preparedStatement = connection.prepareStatement(makeInsertSQL("train"));
            preparedStatement2 = connection.prepareStatement(makeInsertSQL("test"));
            Long2ObjectMap<List<Rating>> initUserMap = initUserMap(preparationContext);
            Cursor events = dataAccessObject.getEvents(Rating.class);
            try {
                this.logger.debug("Writing training data");
                for (Rating rating : events.fast()) {
                    long userId = rating.getUserId();
                    if (initUserMap.containsKey(userId)) {
                        ((List) initUserMap.get(userId)).add(rating.clone());
                    } else {
                        bindRating(preparedStatement, rating);
                        preparedStatement.execute();
                    }
                }
                events.close();
                this.logger.debug("Splitting and writing query users");
                for (List<Rating> list : initUserMap.values()) {
                    int partition = holdout.partition(list);
                    ListIterator<Rating> listIterator = list.listIterator();
                    while (listIterator.hasNext()) {
                        int nextIndex = listIterator.nextIndex();
                        Rating next = listIterator.next();
                        PreparedStatement preparedStatement3 = nextIndex < partition ? preparedStatement : preparedStatement2;
                        bindRating(preparedStatement3, next);
                        preparedStatement3.execute();
                    }
                }
                JDBCUtils.close(new Statement[]{preparedStatement, preparedStatement2});
                this.logger.debug("Committing data");
                connection.commit();
                connection.setAutoCommit(true);
                this.logger.debug("Indexing tables");
                JDBCUtils.execute(connection, "CREATE INDEX train_user_idx ON train (user);");
                JDBCUtils.execute(connection, "CREATE INDEX train_item_idx ON train (item);");
                if (this.useTimestamp) {
                    JDBCUtils.execute(connection, "CREATE INDEX train_timestamp_idx ON train (timestamp);");
                }
                JDBCUtils.execute(connection, "CREATE INDEX test_user_idx ON test (user);");
                JDBCUtils.execute(connection, "CREATE INDEX test_item_idx ON test (item);");
                JDBCUtils.execute(connection, "ANALYZE;");
            } catch (Throwable th) {
                events.close();
                throw th;
            }
        } catch (Throwable th2) {
            JDBCUtils.close(new Statement[]{preparedStatement, preparedStatement2});
            throw th2;
        }
    }

    @Override // org.grouplens.lenskit.eval.data.traintest.TTDataSet
    public String getName() {
        return String.format("%s:%d", this.manager.getSource().getName(), Integer.valueOf(this.foldNumber));
    }

    @Override // org.grouplens.lenskit.eval.data.traintest.TTDataSet
    public Map<String, Object> getAttributes() {
        LinkedHashMap linkedHashMap = new LinkedHashMap();
        linkedHashMap.put("DataSet", this.manager.getSource().getName());
        linkedHashMap.put("Segment", Integer.valueOf(this.foldNumber));
        return linkedHashMap;
    }

    @Override // org.grouplens.lenskit.eval.data.traintest.TTDataSet
    public void release() {
        this.dbFile = null;
    }

    @Override // org.grouplens.lenskit.eval.data.traintest.TTDataSet
    public DAOFactory getTrainFactory() {
        if (this.dataset == null) {
            throw new IllegalStateException("Data set not prepared");
        }
        return this.dataset.getTrainFactory();
    }

    @Override // org.grouplens.lenskit.eval.data.traintest.TTDataSet
    public DAOFactory getTestFactory() {
        if (this.dataset == null) {
            throw new IllegalStateException("Data set not prepared");
        }
        return this.dataset.getTestFactory();
    }

    public String toString() {
        return String.format("{Crossfold %s:%d}", this.manager.getSource(), Integer.valueOf(this.foldNumber));
    }
}
