Tags: Dimension opp ecif __name__ gen rac AC github cut
numpy.argmax(a, axis=None, out=None)

Returns the indices of the maximum values along an axis.
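Parameters:

| Name | Type | Description |
|---|---|---|
| a | array_like | Input array. |
| axis | int, optional | By default, the index is into the flattened array; otherwise it is along the specified axis. |
| out | array, optional | If provided, the result will be inserted into this array. It should be of the appropriate shape and dtype. |

Returns:

| Name | Type | Description |
|---|---|---|
| index_array | ndarray of ints | Array of indices into the array. It has the same shape as a.shape with the dimension along axis removed. |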
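A short sketch of the shape rule and the out parameter from the table above. The array a and the buffer buf are illustrative; the buffer is given dtype np.intp, NumPy's index integer type, so that it matches what argmax produces.

```python
import numpy as np

# Values increase along every axis, so the maximum is always at the last index.
a = np.arange(24).reshape(2, 3, 4)

# The reduced axis is dropped from the result: (2, 3, 4) -> (2, 4) for axis=1.
idx = np.argmax(a, axis=1)
print(idx.shape)  # (2, 4)

# With axis=None the array is flattened, so a single integer comes back.
print(np.argmax(a))  # 23

# `out` must already have the result's shape: (2, 3, 4) -> (2, 3) for axis=2.
buf = np.empty((2, 3), dtype=np.intp)
np.argmax(a, axis=2, out=buf)
print(buf)
```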
See also

amax, unravel_index

Notes
In case of multiple occurrences of the maximum values, the indices corresponding to the first occurrence are returned.
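Because only the first occurrence is reported, ties need separate handling when every position matters. A minimal sketch using the same array b as the example below; np.flatnonzero and reversed slicing are simply one way to do it, not part of argmax itself.

```python
import numpy as np

b = np.array([0, 5, 2, 3, 4, 5])

# argmax reports only the first position holding the maximum.
first = np.argmax(b)

# To see every tied position, compare against the maximum explicitly.
all_max = np.flatnonzero(b == b.max())

# The last occurrence can be found by running argmax on the reversed array.
last = len(b) - 1 - np.argmax(b[::-1])

print(first, all_max, last)  # 1 [1 5] 5
```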
Examples
>>> a = np.arange(6).reshape(2,3)
>>> a
array([[0, 1, 2],
       [3, 4, 5]])
>>> np.argmax(a)
5
>>> np.argmax(a, axis=0)
array([1, 1, 1])
>>> np.argmax(a, axis=1)
array([2, 2])
>>> b = np.arange(6)
>>> b[1] = 5
>>> b
array([0, 5, 2, 3, 4, 5])
>>> np.argmax(b) # Only the first occurrence is returned.
1
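With axis=None the result (5 in the first example) indexes the flattened array. np.unravel_index, listed under "See also", converts such a flat index back into coordinates. A small sketch with an illustrative 2x3 array:

```python
import numpy as np

a = np.array([[0, 7, 2],
              [3, 4, 9]])

# Flat index into a.ravel(), which is [0, 7, 2, 3, 4, 9].
flat = np.argmax(a)

# Convert the flat index into (row, column) for an array of shape (2, 3).
row, col = np.unravel_index(flat, a.shape)

print(flat, row, col)  # 5 1 2
assert a[row, col] == a.max()
```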
In multi-class model training, I use it as follows, with class labels numbered from 0: org_labels = [0, 1, 2, ..., max_label].
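Numbering labels from 0 matters because np.argmax(..., axis=1) returns column positions 0..n_classes-1. If each target row is one-hot (an assumption about what load_data returns as Y; the post does not show it), the argmax of a row is exactly its integer label. Labels starting at 1 would be off by one relative to those positions. A minimal sketch with hypothetical labels:

```python
import numpy as np

# Hypothetical integer labels for 3 classes, numbered from 0 as in the post.
org_labels = np.array([2, 0, 1, 2])
n_classes = 3

# One-hot encode: row i has a 1 in column org_labels[i].
Y = np.eye(n_classes)[org_labels]

# argmax along axis=1 returns the column of the 1, i.e. the label itself.
recovered = np.argmax(Y, axis=1)
assert (recovered == org_labels).all()
```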
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix

# load_data, get_model and EarlyStoppingCallback are the post's own helpers (not shown here).

if __name__ == "__main__":
    width, height = 32, 32
    # X: images, Y: training targets, org_labels: integer labels 0..max_label
    X, Y, org_labels = load_data(dirname="data", resize_pics=(width, height))
    trainX, testX, trainY, testY = train_test_split(X, Y, test_size=0.2, random_state=666)
    print("sample data:")
    print(trainX[0])
    print(trainY[0])
    print(testX[-1])
    print(testY[-1])

    model = get_model(width, height, classes=100)
    filename = 'cnn_handwrite-acc0.8.tflearn'

    # try to load model and resume training
    # try:
    #     model.load(filename)
    #     print("Model loaded OK. Resume training!")
    # except Exception:
    #     pass

    # Initialize our callback with desired accuracy threshold.
    early_stopping_cb = EarlyStoppingCallback(val_acc_thresh=0.6)
    try:
        model.fit(trainX, trainY, validation_set=(testX, testY), n_epoch=500, shuffle=True,
                  snapshot_epoch=True,  # Snapshot (save & evaluate) model every epoch.
                  show_metric=True, batch_size=32, callbacks=early_stopping_cb,
                  run_id='cnn_handwrite')
    except StopIteration:
        # The callback ends training early by raising StopIteration.
        print("OK, stopped early. Good!")
    model.save(filename)

    # predict all data (training and test samples together) and calculate confusion_matrix
    model.load(filename)
    pro_arr = model.predict(X)
    # Each row of pro_arr holds per-class scores; argmax along axis=1 picks the predicted class.
    predict_labels = np.argmax(pro_arr, axis=1)
    print(classification_report(org_labels, predict_labels))
    print(confusion_matrix(org_labels, predict_labels))
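The evaluation step can be checked without a trained model. The sketch below uses made-up probability rows in place of model.predict(X). The helper confusion_from_labels is hypothetical: a NumPy-only stand-in showing what a confusion matrix counts, with rows as true labels and columns as predicted labels. scikit-learn's confusion_matrix uses the same row/column convention, but by default it includes only the labels that actually appear in y_true or y_pred.

```python
import numpy as np

def confusion_from_labels(y_true, y_pred, n_classes):
    """Hypothetical helper: rows are true labels, columns are predicted labels."""
    # Encode each (true, pred) pair as one integer, then count occurrences.
    flat = np.asarray(y_true) * n_classes + np.asarray(y_pred)
    counts = np.bincount(flat, minlength=n_classes * n_classes)
    return counts.reshape(n_classes, n_classes)

# Made-up softmax outputs for 4 samples and 3 classes (illustration only).
pro_arr = np.array([[0.7, 0.2, 0.1],
                    [0.1, 0.8, 0.1],
                    [0.3, 0.3, 0.4],
                    [0.6, 0.3, 0.1]])
org_labels = np.array([0, 1, 2, 1])

predict_labels = np.argmax(pro_arr, axis=1)  # [0, 1, 2, 0]
cm = confusion_from_labels(org_labels, predict_labels, 3)
accuracy = np.trace(cm) / cm.sum()           # correct predictions on the diagonal
print(cm)
print(accuracy)  # 0.75
```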
Original post: https://www.cnblogs.com/bonelee/p/8976380.html