Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 90 additions & 18 deletions utils/src/main/java/com/cloud/utils/ssh/SshHelper.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,26 @@
package com.cloud.utils.ssh;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;

import org.apache.log4j.Logger;

import com.trilead.ssh2.ChannelCondition;

import com.trilead.ssh2.Connection;
import com.trilead.ssh2.Session;
import com.cloud.utils.Pair;

public class SshHelper {
private static final int DEFAULT_CONNECT_TIMEOUT = 180000;
private static final int DEFAULT_KEX_TIMEOUT = 60000;

/**
* Waiting time to check if the SSH session was successfully opened. This value (of 1000
* milliseconds) represents one (1) second.
*/
private static final long WAITING_OPEN_SSH_SESSION = 1000;

private static final Logger s_logger = Logger.getLogger(SshHelper.class);

public static Pair<Boolean, String> sshExecute(String host, int port, String user, File pemKeyFile, String password, String command) throws Exception {
Expand All @@ -40,19 +48,19 @@ public static Pair<Boolean, String> sshExecute(String host, int port, String use
}

public static void scpTo(String host, int port, String user, File pemKeyFile, String password, String remoteTargetDirectory, String localFile, String fileMode)
throws Exception {
throws Exception {

scpTo(host, port, user, pemKeyFile, password, remoteTargetDirectory, localFile, fileMode, DEFAULT_CONNECT_TIMEOUT, DEFAULT_KEX_TIMEOUT);
}

public static void scpTo(String host, int port, String user, File pemKeyFile, String password, String remoteTargetDirectory, byte[] data, String remoteFileName,
String fileMode) throws Exception {
String fileMode) throws Exception {

scpTo(host, port, user, pemKeyFile, password, remoteTargetDirectory, data, remoteFileName, fileMode, DEFAULT_CONNECT_TIMEOUT, DEFAULT_KEX_TIMEOUT);
}

public static void scpTo(String host, int port, String user, File pemKeyFile, String password, String remoteTargetDirectory, String localFile, String fileMode,
int connectTimeoutInMs, int kexTimeoutInMs) throws Exception {
int connectTimeoutInMs, int kexTimeoutInMs) throws Exception {

com.trilead.ssh2.Connection conn = null;
com.trilead.ssh2.SCPClient scpClient = null;
Expand Down Expand Up @@ -88,7 +96,7 @@ public static void scpTo(String host, int port, String user, File pemKeyFile, St
}

public static void scpTo(String host, int port, String user, File pemKeyFile, String password, String remoteTargetDirectory, byte[] data, String remoteFileName,
String fileMode, int connectTimeoutInMs, int kexTimeoutInMs) throws Exception {
String fileMode, int connectTimeoutInMs, int kexTimeoutInMs) throws Exception {

com.trilead.ssh2.Connection conn = null;
com.trilead.ssh2.SCPClient scpClient = null;
Expand Down Expand Up @@ -123,7 +131,8 @@ public static void scpTo(String host, int port, String user, File pemKeyFile, St
}

public static Pair<Boolean, String> sshExecute(String host, int port, String user, File pemKeyFile, String password, String command, int connectTimeoutInMs,
int kexTimeoutInMs, int waitResultTimeoutInMs) throws Exception {
int kexTimeoutInMs,
int waitResultTimeoutInMs) throws Exception {

com.trilead.ssh2.Connection conn = null;
com.trilead.ssh2.Session sess = null;
Expand All @@ -144,7 +153,7 @@ public static Pair<Boolean, String> sshExecute(String host, int port, String use
throw new Exception(msg);
}
}
sess = conn.openSession();
sess = openConnectionSession(conn);

sess.execCommand(command);

Expand All @@ -156,22 +165,22 @@ public static Pair<Boolean, String> sshExecute(String host, int port, String use

int currentReadBytes = 0;
while (true) {
throwSshExceptionIfStdoutOrStdeerIsNull(stdout, stderr);

if ((stdout.available() == 0) && (stderr.available() == 0)) {
int conditions =
sess.waitForCondition(ChannelCondition.STDOUT_DATA | ChannelCondition.STDERR_DATA | ChannelCondition.EOF | ChannelCondition.EXIT_STATUS,
int conditions = sess.waitForCondition(ChannelCondition.STDOUT_DATA | ChannelCondition.STDERR_DATA | ChannelCondition.EOF | ChannelCondition.EXIT_STATUS,
waitResultTimeoutInMs);

if ((conditions & ChannelCondition.TIMEOUT) != 0) {
String msg = "Timed out in waiting SSH execution result";
s_logger.error(msg);
throw new Exception(msg);
}
throwSshExceptionIfConditionsTimeout(conditions);

if ((conditions & ChannelCondition.EXIT_STATUS) != 0) {
if ((conditions & (ChannelCondition.STDOUT_DATA | ChannelCondition.STDERR_DATA)) == 0) {
break;
}
break;
}

if (canEndTheSshConnection(waitResultTimeoutInMs, sess, conditions)) {
break;
}

}

while (stdout.available() > 0) {
Expand All @@ -189,11 +198,12 @@ public static Pair<Boolean, String> sshExecute(String host, int port, String use

if (sess.getExitStatus() == null) {
//Exit status is NOT available. Returning failure result.
s_logger.error(String.format("SSH execution of command %s has no exit status set. Result output: %s", command, result));
return new Pair<Boolean, String>(false, result);
}

if (sess.getExitStatus() != null && sess.getExitStatus().intValue() != 0) {
s_logger.error("SSH execution of command " + command + " has an error status code in return. result output: " + result);
s_logger.error(String.format("SSH execution of command %s has an error status code in return. Result output: %s", command, result));
return new Pair<Boolean, String>(false, result);
}

Expand All @@ -206,4 +216,66 @@ public static Pair<Boolean, String> sshExecute(String host, int port, String use
conn.close();
}
}

/**
* It gets a {@link Session} from the given {@link Connection}; then, it waits
* {@value #WAITING_OPEN_SSH_SESSION} milliseconds before returning the session, given a time to
* ensure that the connection is open before proceeding the execution.
*/
protected static Session openConnectionSession(Connection conn) throws IOException, InterruptedException {
Session sess = conn.openSession();
Thread.sleep(WAITING_OPEN_SSH_SESSION);
return sess;
}

/**
* Handles the SSH connection in case of timeout or exit. If the session ends with a timeout
* condition, it throws an exception; if the channel reaches an end of file condition, but it
* does not have an exit status, it returns true to break the loop; otherwise, it returns
* false.
*/
protected static boolean canEndTheSshConnection(int waitResultTimeoutInMs, com.trilead.ssh2.Session sess, int conditions) throws SshException {
if (isChannelConditionEof(conditions)) {
int newConditions = sess.waitForCondition(ChannelCondition.EXIT_STATUS, waitResultTimeoutInMs);
throwSshExceptionIfConditionsTimeout(newConditions);
if ((newConditions & ChannelCondition.EXIT_STATUS) != 0) {
return true;
}
}
return false;
}

/**
* It throws a {@link SshException} if the channel condition is {@link ChannelCondition#TIMEOUT}
*/
protected static void throwSshExceptionIfConditionsTimeout(int conditions) throws SshException {
if ((conditions & ChannelCondition.TIMEOUT) != 0) {
String msg = "Timed out in waiting for SSH execution exit status";
s_logger.error(msg);
throw new SshException(msg);
}
}

/**
* Checks if the channel condition mask is of {@link ChannelCondition#EOF} and not
* {@link ChannelCondition#STDERR_DATA} or {@link ChannelCondition#STDOUT_DATA}.
*/
protected static boolean isChannelConditionEof(int conditions) {
if ((conditions & ChannelCondition.EOF) != 0) {
return true;
}
return false;
}

/**
* Checks if the SSH session {@link com.trilead.ssh2.Session#getStdout()} or
* {@link com.trilead.ssh2.Session#getStderr()} is null.
*/
protected static void throwSshExceptionIfStdoutOrStdeerIsNull(InputStream stdout, InputStream stderr) throws SshException {
if (stdout == null || stderr == null) {
String msg = "Stdout or Stderr of SSH session is null";
s_logger.error(msg);
throw new SshException(msg);
}
}
}
151 changes: 151 additions & 0 deletions utils/src/test/java/com/cloud/utils/ssh/SshHelperTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
//
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//

package com.cloud.utils.ssh;

import java.io.IOException;
import java.io.InputStream;

import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mockito;
import org.powermock.api.mockito.PowerMockito;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;

import com.trilead.ssh2.ChannelCondition;
import com.trilead.ssh2.Connection;
import com.trilead.ssh2.Session;

@PrepareForTest({ Thread.class, SshHelper.class })
@RunWith(PowerMockRunner.class)
public class SshHelperTest {

@Test
public void canEndTheSshConnectionTest() throws Exception {
PowerMockito.spy(SshHelper.class);
Session mockedSession = Mockito.mock(Session.class);

PowerMockito.doReturn(true).when(SshHelper.class, "isChannelConditionEof", Mockito.anyInt());
Mockito.when(mockedSession.waitForCondition(ChannelCondition.EXIT_STATUS, 1l)).thenReturn(0);
PowerMockito.doNothing().when(SshHelper.class, "throwSshExceptionIfConditionsTimeout", Mockito.anyInt());

SshHelper.canEndTheSshConnection(1, mockedSession, 0);

PowerMockito.verifyStatic();
SshHelper.isChannelConditionEof(Mockito.anyInt());
SshHelper.throwSshExceptionIfConditionsTimeout(Mockito.anyInt());

Mockito.verify(mockedSession).waitForCondition(ChannelCondition.EXIT_STATUS, 1l);
}

@Test(expected = SshException.class)
public void throwSshExceptionIfConditionsTimeout() throws SshException {
SshHelper.throwSshExceptionIfConditionsTimeout(ChannelCondition.TIMEOUT);
}

@Test
public void doNotThrowSshExceptionIfConditionsClosed() throws SshException {
SshHelper.throwSshExceptionIfConditionsTimeout(ChannelCondition.CLOSED);
}

@Test
public void doNotThrowSshExceptionIfConditionsStdout() throws SshException {
SshHelper.throwSshExceptionIfConditionsTimeout(ChannelCondition.STDOUT_DATA);
}

@Test
public void doNotThrowSshExceptionIfConditionsStderr() throws SshException {
SshHelper.throwSshExceptionIfConditionsTimeout(ChannelCondition.STDERR_DATA);
}

@Test
public void doNotThrowSshExceptionIfConditionsEof() throws SshException {
SshHelper.throwSshExceptionIfConditionsTimeout(ChannelCondition.EOF);
}

@Test
public void doNotThrowSshExceptionIfConditionsExitStatus() throws SshException {
SshHelper.throwSshExceptionIfConditionsTimeout(ChannelCondition.EXIT_STATUS);
}

@Test
public void doNotThrowSshExceptionIfConditionsExitSignal() throws SshException {
SshHelper.throwSshExceptionIfConditionsTimeout(ChannelCondition.EXIT_SIGNAL);
}

@Test
public void isChannelConditionEofTestTimeout() {
Assert.assertFalse(SshHelper.isChannelConditionEof(ChannelCondition.TIMEOUT));
}

@Test
public void isChannelConditionEofTestClosed() {
Assert.assertFalse(SshHelper.isChannelConditionEof(ChannelCondition.CLOSED));
}

@Test
public void isChannelConditionEofTestStdout() {
Assert.assertFalse(SshHelper.isChannelConditionEof(ChannelCondition.STDOUT_DATA));
}

@Test
public void isChannelConditionEofTestStderr() {
Assert.assertFalse(SshHelper.isChannelConditionEof(ChannelCondition.STDERR_DATA));
}

@Test
public void isChannelConditionEofTestEof() {
Assert.assertTrue(SshHelper.isChannelConditionEof(ChannelCondition.EOF));
}

@Test
public void isChannelConditionEofTestExitStatus() {
Assert.assertFalse(SshHelper.isChannelConditionEof(ChannelCondition.EXIT_STATUS));
}

@Test
public void isChannelConditionEofTestExitSignal() {
Assert.assertFalse(SshHelper.isChannelConditionEof(ChannelCondition.EXIT_SIGNAL));
}

@Test(expected = SshException.class)
public void throwSshExceptionIfStdoutOrStdeerIsNullTestNull() throws SshException {
SshHelper.throwSshExceptionIfStdoutOrStdeerIsNull(null, null);
}

@Test
public void throwSshExceptionIfStdoutOrStdeerIsNullTestNotNull() throws SshException {
InputStream inputStream = Mockito.mock(InputStream.class);
SshHelper.throwSshExceptionIfStdoutOrStdeerIsNull(inputStream, inputStream);
}

@Test
public void openConnectionSessionTest() throws IOException, InterruptedException {
Connection conn = Mockito.mock(Connection.class);
PowerMockito.mockStatic(Thread.class);
SshHelper.openConnectionSession(conn);

Mockito.verify(conn).openSession();

PowerMockito.verifyStatic();
Thread.sleep(Mockito.anyLong());
}
}